{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Types (
  -- * Reasoning helpers
  subst1, subst2,

  -- * Reified evidence of a type class
  Dict(..),

  -- * Type-level naturals
  pattern SZ, pattern SS,
  fromSNat', sameNat',
  snatPlus, snatMinus, snatMul,
  snatSucc,

  -- * Type-level lists
  type (++),
  Replicate,
  lemReplicateSucc,
  MapJust,
  lemMapJustEmpty, lemMapJustCons,
  Head,
  Tail,
  Init,
  Last,

  -- * Unsafe
  unsafeCoerceRefl,
) where

import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits
import GHC.TypeNats qualified as TN
import Unsafe.Coerce qualified


-- Reasoning helpers

-- | Congruence for a one-argument type constructor: equal arguments give
-- equal applications. The constructor @f@ comes first so it can be supplied
-- with a type application.
subst1 :: forall f a b. a :~: b -> f a :~: f b
subst1 Refl = Refl

-- | Congruence in the first argument of a two-argument type constructor:
-- equal first arguments give equal applications, with the second argument
-- @c@ held fixed. @f@ and @c@ come first so they can be supplied with type
-- applications.
subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
subst2 Refl = Refl

-- | Evidence for the constraint @cls x@, reified as a value. Constructing a
-- 'Dict' requires the instance; pattern matching on it brings the instance
-- back into scope.
data Dict cls x where
  Dict :: cls x => Dict cls x

-- | The value of a type-level natural as an 'Int'. The conversion is a plain
-- 'fromIntegral' from 'Integer', so values above @maxBound :: Int@ wrap
-- around instead of failing.
fromSNat' :: SNat n -> Int
fromSNat' = fromIntegral . fromSNat

-- | Decide whether two type-level naturals are equal. This is 'sameNat', but
-- it takes 'SNat' singletons directly instead of requiring 'KnownNat'
-- instances from the caller.
sameNat' :: SNat n -> SNat m -> Maybe (n :~: m)
-- Matching on the 'SNat' pattern brings the 'KnownNat' instance for each
-- index into scope, which 'sameNat' needs.
sameNat' n@SNat m@SNat = sameNat n m

-- | Matches an 'SNat' whose value is 0 and brings @n ~ 0@ into scope. As an
-- expression, @SZ@ builds the singleton for 0.
pattern SZ :: () => (n ~ 0) => SNat n
pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
  where SZ = SNat

-- | Matches a successor. On an @'SNat' np1@ whose value is at least 1, it
-- binds the predecessor @sn :: 'SNat' n@ and brings @n + 1 ~ np1@ into scope.
-- As an expression, @SS sn@ is @snatSucc sn@.
pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
  where SS = snatSucc

-- SZ matches 0 and SS matches every value >= 1, so together they are
-- exhaustive.
{-# COMPLETE SZ, SS #-}

-- | The successor of a type-level natural. Matching on 'SNat' provides
-- @KnownNat n@. Producing @KnownNat (n + 1)@ from it depends on the
-- KnownNat solver plugin enabled for this module.
snatSucc :: SNat n -> SNat (n + 1)
snatSucc SNat = SNat

-- | Some predecessor @n@ of @np1@, packaged with the proof that @n + 1@ is
-- @np1@. The predecessor's index is existentially hidden.
data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)

-- | Decompose a type-level natural into its predecessor. Returns 'Nothing'
-- for 0. Otherwise returns the 'SNat' for @np1 - 1@ together with
-- evidence that adding one to it gives back @np1@.
snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
snatPred snp1 =
  withKnownNat snp1 $
    -- Each matching branch brings a different fact about np1 into scope
    -- (EQI gives np1 ~ 1, LTI gives 1 < np1). That fact is what lets
    -- KnownNat (np1 - 1) and (np1 - 1) + 1 ~ np1 be solved, presumably by the
    -- type-checker plugins enabled for this module. The two Just branches
    -- look identical but are checked under different evidence, so they
    -- cannot be merged.
    case cmpNat (Proxy @1) (Proxy @np1) of
      LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
      EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
      GTI -> Nothing

-- This should be a function in base
-- | Add two type-level naturals at the value level. The sum of the
-- underlying values is wrapped in an existentially indexed 'SNat'. That
-- value is then reinterpreted as @'SNat' (n + m)@ with
-- 'Unsafe.Coerce.unsafeCoerce'. The two types differ only in their
-- type-level index.
snatPlus :: SNat n -> SNat m -> SNat (n + m)
snatPlus n m =
  TN.withSomeSNat (TN.fromSNat n + TN.fromSNat m) Unsafe.Coerce.unsafeCoerce

-- This should be a function in base
-- | Subtract two type-level naturals at the value level. The result is
-- reinterpreted as @'SNat' (n - m)@ with 'Unsafe.Coerce.unsafeCoerce'.
--
-- The subtraction is on natural numbers, so it throws an arithmetic
-- underflow exception when @m > n@. The difference is forced with 'seq'
-- before it is wrapped. This surfaces that exception here instead of
-- producing a lazily broken 'SNat'.
snatMinus :: SNat n -> SNat m -> SNat (n - m)
snatMinus n m =
  let res = TN.fromSNat n - TN.fromSNat m
  in res `seq` TN.withSomeSNat res Unsafe.Coerce.unsafeCoerce

-- This should be a function in base
-- | Multiply two type-level naturals at the value level. The product is
-- reinterpreted as @'SNat' (n * m)@ with 'Unsafe.Coerce.unsafeCoerce'. The
-- source and target types differ only in their type-level index.
snatMul :: SNat n -> SNat m -> SNat (n * m)
snatMul n m =
  TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce


-- | Type-level list append: the elements of @xs@ followed by those of @ys@.
-- It reduces by recursion on the first list.
type family xs ++ ys where
  '[] ++ ys = ys
  (x ': rest) ++ ys = x ': (rest ++ ys)

-- | A type-level list holding @count@ copies of @elt@. The equations are
-- tried in order, so the recursive one applies only when @count@ is known to
-- differ from 0.
type family Replicate count elt where
  Replicate 0 elt = '[]
  Replicate count elt = elt ': Replicate (count - 1) elt

-- | Prepending one more @a@ to @'Replicate' n a@ gives
-- @'Replicate' (n + 1) a@. GHC does not check this equation; it is asserted
-- with 'unsafeCoerceRefl'.
lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
lemReplicateSucc = unsafeCoerceRefl

-- | Wrap every element of a type-level list in 'Just'. The family is
-- declared injective, so the argument list is determined by the result.
type family MapJust xs = ys | ys -> xs where
  MapJust '[] = '[]
  MapJust (x ': rest) = 'Just x ': MapJust rest

-- | If @'MapJust' sh@ is the empty list, then @sh@ is empty too. The
-- conclusion is asserted with 'unsafeCoerceRefl' rather than checked by GHC.
lemMapJustEmpty :: MapJust sh :~: '[] -> sh :~: '[]
lemMapJustEmpty Refl = unsafeCoerceRefl

-- | If @'MapJust' sh@ starts with @'Just n@, then @sh@ is @n@ consed onto its
-- own 'Tail'. The conclusion is asserted with 'unsafeCoerceRefl' rather than
-- checked by GHC.
lemMapJustCons :: MapJust sh :~: Just n : sh' -> sh :~: n : Tail sh
lemMapJustCons Refl = unsafeCoerceRefl

-- | The first element of a non-empty type-level list. It does not reduce on
-- @'[]@.
type family Head xs where
  Head (hd ': _) = hd

-- | A non-empty type-level list without its first element. It does not
-- reduce on @'[]@.
type family Tail xs where
  Tail (_ ': rest) = rest

-- | A non-empty type-level list without its last element. It does not reduce
-- on @'[]@.
type family Init xs where
  Init (x ': y ': rest) = x ': Init (y ': rest)
  Init '[_] = '[]

-- | The last element of a non-empty type-level list. It does not reduce on
-- @'[]@.
type family Last xs where
  Last (_ ': y ': rest) = Last (y ': rest)
  Last '[final] = final


-- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to
-- only typecheck for actual type equalities. One cannot, e.g. accidentally
-- write this:
--
-- @
-- foo :: Proxy a -> Proxy b -> a :~: b
-- foo = unsafeCoerceRefl
-- @
--
-- which would have been permitted with normal 'Unsafe.Coerce.unsafeCoerce',
-- but would have resulted in interesting memory errors at runtime.
--
-- The equality itself is never checked. Use this only for equations that
-- are known to hold but that GHC cannot prove.
unsafeCoerceRefl :: a :~: b
unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl