{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE TypeFamilies #-}

module Data.Type.Nat.Singleton.Safe (
  -- * Natural Number Singletons
  SNat (..),
  toSafe,
  fromSafe,
  fromSNat,
  fromSNatRaw,
  plus,
  decSNat,

  -- * Existential Wrapper
  SomeSNat (..),
  withSomeSNat,
  toSomeSNat,
  toSomeSNatRaw,
  fromSomeSNat,
  fromSomeSNatRaw,

  -- * Laws
  plusUnitL,
  plusUnitR,
  plusCommS,
  plusComm,
  plusAssoc,

  -- * Linking Type-Level and Value-Level
  KnownNat (..),
  withKnownNat,

  -- * Specialised target for conversion
  SNatRep,
) where

import Control.DeepSeq (NFData (..))
import Data.Kind (Constraint, Type)
import Data.Maybe (isJust)
import Data.Proxy (Proxy (..))
import Data.Type.Equality (type (:~:) (Refl))
import Data.Type.Equality qualified as Eq
import Data.Type.Nat (Nat (..), type (+))
import Data.Type.Nat.Singleton.Fast (SNatRep)
import Data.Type.Nat.Singleton.Fast qualified as Fast

{- $setup
>>> import Data.Type.Nat.Singleton.Safe.Arbitrary
-}

--------------------------------------------------------------------------------
-- Natural Number Singletons
--------------------------------------------------------------------------------

-- | @'SNat' n@ is the singleton type for natural numbers.
--
-- For every type-level @n@ there is exactly one fully-defined value of type
-- @'SNat' n@, so pattern matching on it tells the type checker what @n@ is.
-- The representation is unary: one 'S' constructor per successor.
-- See 'toSafe' and 'fromSafe' for converting to and from 'Fast.SNat'.
type SNat :: Nat -> Type
data SNat n where
  -- | The singleton for zero.
  Z :: SNat Z
  -- | The singleton for the successor of @n@.
  S :: SNat n -> SNat (S n)

-- | Structural equality, decided by 'decSNat'.
--
-- Both arguments share the index @n@, so for fully-defined values the result
-- is always 'True'. The comparison still walks both spines, which forces them.
instance Eq (SNat n) where
  (==) :: SNat n -> SNat n -> Bool
  lhs == rhs =
    case decSNat lhs rhs of
      Just _ -> True
      Nothing -> False

-- | Renders a singleton as a chain of constructors, e.g. @S (S Z)@.
-- Nested successors are parenthesised whenever the surrounding precedence
-- exceeds 10 (function application).
instance Show (SNat n) where
  showsPrec :: Int -> SNat n -> ShowS
  showsPrec _ Z = showString "Z"
  showsPrec prec (S predecessor) =
    showParen (prec > 10) (showString "S " . showsPrec 11 predecessor)

-- | Fully evaluates a singleton by walking every 'S' down to 'Z'.
instance NFData (SNat n) where
  rnf :: SNat n -> ()
  rnf = \case
    Z -> ()
    S predecessor -> rnf predecessor

-- | Convert from the efficient representation 'Fast.SNat' to the safe
-- representation 'SNat', rebuilding one 'S' per successor step.
toSafe :: Fast.SNat n -> SNat n
toSafe = \case
  Fast.Z -> Z
  Fast.S predecessor -> S (toSafe predecessor)

-- | Convert from the safe representation 'SNat' to the efficient
-- representation 'Fast.SNat', peeling off one 'S' per step.
fromSafe :: SNat n -> Fast.SNat n
fromSafe = \case
  Z -> Fast.Z
  S predecessor -> Fast.S (fromSafe predecessor)

-- | @'fromSNat' n@ returns the numeric representation of @'SNat' n@,
-- i.e. the number of 'S' constructors in the singleton.
--
-- The count is accumulated strictly in a tail-recursive loop. This avoids
-- building one pending addition per 'S', as @1 + fromSNat n'@ would.
fromSNat :: (Integral i) => SNat n -> i
fromSNat = go 0
  where
    -- The signature is required: the index changes at every step
    -- (polymorphic recursion over the GADT).
    go :: (Num j) => j -> SNat k -> j
    go acc Z = acc
    go acc (S rest) =
      let acc' = acc + 1
       in acc' `seq` go acc' rest
{-# SPECIALIZE fromSNat :: SNat n -> SNatRep #-}

-- | 'fromSNat' at the fixed 'SNatRep' result type, so callers need no
-- annotation to pick the numeric type.
fromSNatRaw :: SNat n -> SNatRep
fromSNatRaw = fromSNat
{-# INLINE fromSNatRaw #-}

-- | Addition for natural number singletons, by recursion on the first
-- argument: each 'S' of the left operand is moved onto the result.
plus :: SNat n -> SNat m -> SNat (n + m)
plus Z rhs = rhs
plus (S lhs) rhs = S (plus lhs rhs)

-- | Decidable equality for natural number singletons.
--
-- Returns @'Just' 'Refl'@, a proof that @m@ and @n@ are equal, when both
-- singletons have the same number of 'S' constructors. Returns 'Nothing'
-- otherwise; the walk stops at the first position where one side is 'Z' and
-- the other is 'S'.
decSNat :: SNat m -> SNat n -> Maybe (m :~: n)
decSNat Z Z = Just Refl
decSNat (S a) (S b) =
  case decSNat a b of
    Just Refl -> Just Refl
    Nothing -> Nothing
decSNat _ _ = Nothing

--------------------------------------------------------------------------------
-- Existential Wrapper
--------------------------------------------------------------------------------

-- | An existential wrapper around natural number singletons.
--
-- The type index is hidden, so a singleton built from a runtime number
-- (see 'toSomeSNat') can be stored and passed around. Use 'withSomeSNat'
-- to get the 'SNat' back. The field is strict, so constructing a 'SomeSNat'
-- evaluates the wrapped singleton to its outermost constructor.
type SomeSNat :: Type
data SomeSNat = forall (n :: Nat). SomeSNat !(SNat n)

-- | Fully evaluates the wrapped singleton.
instance NFData SomeSNat where
  rnf :: SomeSNat -> ()
  rnf = \case
    SomeSNat wrapped -> rnf wrapped

-- Derived: shows as @SomeSNat@ applied to the wrapped singleton,
-- e.g. @SomeSNat (S Z)@.
deriving instance Show SomeSNat

-- | Two wrapped singletons are equal when 'decSNat' can prove their hidden
-- indices equal, i.e. when they denote the same number.
instance Eq SomeSNat where
  (==) :: SomeSNat -> SomeSNat -> Bool
  SomeSNat lhs == SomeSNat rhs =
    case decSNat lhs rhs of
      Nothing -> False
      Just _ -> True

-- | Evaluate a term with access to the underlying @'SNat'@.
--
-- The continuation must be polymorphic in the index, because 'SomeSNat'
-- hides it and it cannot escape into the result type.
withSomeSNat :: (forall n. SNat n -> a) -> SomeSNat -> a
withSomeSNat cont wrapped =
  case wrapped of
    SomeSNat sn -> cont sn

{-| @'toSomeSNat' n@ constructs the singleton @'SNat' n@, wrapped in a
'SomeSNat'.

The singleton is built by applying 'S' to 'Z' @n@ times, so this takes time
linear in @n@. A zero or negative input yields the singleton for zero.

>>> toSomeSNat (3 :: Int)
SomeSNat (S (S (S Z)))

>>> toSomeSNat (-2 :: Int)
SomeSNat Z

prop> toSomeSNat (fromSomeSNat n) == n
-}
toSomeSNat :: (Integral i) => i -> SomeSNat
toSomeSNat n = iterate' n (withSomeSNat (SomeSNat . S)) (SomeSNat Z)

{-| @'toSomeSNatRaw' n@ constructs the singleton @'SNat' n@ from its
'SNatRep' representation. It is 'toSomeSNat' at the fixed input type
'SNatRep'.

prop> toSomeSNatRaw (fromSomeSNatRaw n) == n
-}
toSomeSNatRaw :: SNatRep -> SomeSNat
toSomeSNatRaw = toSomeSNat

-- | @'fromSomeSNat' n@ returns the numeric representation of the wrapped
-- singleton, as computed by 'fromSNat'.
fromSomeSNat :: (Integral i) => SomeSNat -> i
fromSomeSNat (SomeSNat sn) = fromSNat sn

-- | @'fromSomeSNatRaw' n@ returns the 'SNatRep' representation of the
-- wrapped singleton. It is 'fromSomeSNat' at the fixed result type 'SNatRep'.
fromSomeSNatRaw :: SomeSNat -> SNatRep
fromSomeSNatRaw = fromSomeSNat

--------------------------------------------------------------------------------
-- Laws
--------------------------------------------------------------------------------

-- | 'Z' is a left identity for '+'. This holds by type-level computation
-- alone, so the proof is 'Refl' and the proxy is never inspected.
plusUnitL :: Proxy n -> Z + n :~: n
plusUnitL = const Refl

-- | 'Z' is a right identity for '+'. Unlike the left-identity law, this one
-- is proved by induction on the singleton.
plusUnitR :: SNat n -> n + Z :~: n
plusUnitR = \case
  Z -> Refl
  S rest ->
    case plusUnitR rest of
      Refl -> Refl

-- | A successor on the right operand can be pulled outside the sum:
-- @S (n + m) = n + S m@, proved by induction on @n@.
plusCommS :: SNat n -> Proxy m -> S (n + m) :~: n + S m
plusCommS Z _ = Refl
plusCommS (S rest) proxy =
  case plusCommS rest proxy of
    Refl -> Refl

-- | Commutativity of '+', by induction on the first singleton.
--
-- The base case flips 'plusUnitR'. The step case wraps the induction
-- hypothesis in 'S' and chains it with 'plusCommS' to move that 'S' onto
-- the right-hand operand.
plusComm :: SNat n -> SNat m -> n + m :~: m + n
plusComm Z m = Eq.sym (plusUnitR m)
plusComm (S rest) m =
  Eq.trans (Eq.apply Refl (plusComm rest m)) (plusCommS m (erase rest))

-- | Associativity of '+', by induction on the first singleton. The other two
-- operands are only needed at the type level, so they are passed as proxies.
plusAssoc :: SNat n -> Proxy m -> Proxy l -> (n + m) + l :~: n + (m + l)
plusAssoc Z _ _ = Refl
plusAssoc (S rest) pm pl =
  case plusAssoc rest pm pl of
    Refl -> Refl

--------------------------------------------------------------------------------
-- Linking Type-Level and Value-Level
--------------------------------------------------------------------------------

-- | Type-level naturals whose singleton can be obtained implicitly, through
-- instance resolution, instead of being passed as an argument.
type KnownNat :: Nat -> Constraint
class KnownNat n where
  -- | The singleton for @n@.
  natSing :: SNat n

-- | Zero's singleton is the 'Z' constructor.
instance KnownNat Z where
  natSing :: SNat Z
  natSing = Z

-- | A successor's singleton is 'S' applied to the predecessor's singleton,
-- which is itself found through the @KnownNat n@ context.
instance (KnownNat n) => KnownNat (S n) where
  natSing :: SNat (S n)
  natSing = S natSing

-- | Discharge a @'KnownNat' n@ constraint using an explicit singleton.
--
-- At 'Z' the base instance applies directly. At each 'S', the recursive call
-- supplies @KnownNat@ for the predecessor, and the successor instance lifts
-- it to the current index. The recursion therefore takes time linear in @n@.
withKnownNat :: SNat n -> ((KnownNat n) => r) -> r
withKnownNat sn body =
  case sn of
    Z -> body
    S rest -> withKnownNat rest body

--------------------------------------------------------------------------------
-- Helper Functions
--------------------------------------------------------------------------------

-- | @`erase` x@ discards the value @x@ and keeps only its type index, as a
-- @`Proxy`@. The argument is never evaluated.
erase :: f a -> Proxy a
erase = const Proxy
{-# INLINE erase #-}

-- | @`iterate'` i f x@ applies @f@ to @x@ @i@ times. Each intermediate result
-- is forced to weak head normal form before the next step. A count of zero
-- or less returns @x@ unchanged.
iterate' :: (Integral i) => i -> (a -> a) -> a -> a
iterate' count f = loop count
  where
    loop remaining acc
      | remaining > 0 = loop (remaining - 1) $! f acc
      | otherwise = acc