-- |
-- Module      : Data.SNat
-- Description : Singleton naturals
--
-- Runtime singleton values tied to type-level naturals, with helpers for
-- building, adding and deconstructing them.
module Data.SNat(
  Nat(..), toNatural, fromNatural,
  SNat(..),  snatToNat,
  SNatI(..), snat, withSNat, reify, reflect,
  type (+),
  N0, N1, N2, N3,
  s0, s1, s2, s3,
  sPlus,
  axiomPlusZ,
  axiomAssoc,
  SNat_(..), snat_,
  prev,
  next,
  ToInt(..),
 ) where

-- Singleton nats are purely runtime

import Data.Type.Equality
import Data.Type.Nat
import Test.QuickCheck
import Unsafe.Coerce (unsafeCoerce)
import Prelude hiding (pred, succ)

-----------------------------------------------------
-- axioms (use unsafeCoerce)
-----------------------------------------------------

-- | @0@ is a right identity for '+': @m + 0 = m@.
--
-- 'Plus' only reduces once its first argument is a known constructor, so
-- @m + 'Z'@ is stuck when @m@ is a type variable and GHC cannot prove this
-- equation by itself. The equation holds for every finite 'Nat', so 'Refl'
-- is asserted with 'unsafeCoerce'. This avoids proving it by induction on
-- @m@, which would require runtime evidence for @m@ (e.g. an 'SNat' @m@).
axiomPlusZ :: forall m. m + Z :~: m
axiomPlusZ = unsafeCoerce Refl

-- | '+' is associative: @p + (m + n) = (p + m) + n@.
--
-- GHC cannot prove this for type variables @p@, @m@, @n@, because 'Plus'
-- only reduces once its first argument is a known constructor. The law
-- holds for every finite 'Nat', so 'Refl' is asserted with 'unsafeCoerce'.
-- This avoids proving it by induction on @p@, which would require runtime
-- evidence for @p@ (e.g. an 'SNat' @p@).
axiomAssoc :: forall p m n. p + (m + n) :~: (p + m) + n
axiomAssoc = unsafeCoerce Refl

-----------------------------------------------------
-- Nats (singleton nats and implicit singletons)
-----------------------------------------------------

-- | Type-level zero.
type N0 = 'Z

-- | Type-level one.
type N1 = 'S N0

-- | Type-level two.
type N2 = 'S N1

-- | Type-level three.
type N3 = 'S N2

---------------------------------------------------------
-- Singletons and instances
---------------------------------------------------------

-- | The singleton for 'N0' (zero).
s0 :: SNat N0
s0 = snat

-- | The singleton for 'N1' (one).
s1 :: SNat N1
s1 = snat

-- | The singleton for 'N2' (two).
s2 :: SNat N2
s2 = snat

-- | The singleton for 'N3' (three).
s3 :: SNat N3
s3 = snat

-- | Always generates the singleton recovered from the 'SNatI' constraint
-- (via 'snat'). The generator is constant, so the default 'shrink' (no
-- shrinks) is kept.
--
-- NOTE(review): this appears to be an orphan instance ('Arbitrary' comes
-- from QuickCheck and 'SNat' is imported from "Data.Type.Nat").
instance (SNatI n) => Arbitrary (SNat n) where
  arbitrary :: (SNatI n) => Gen (SNat n)
  arbitrary = pure snat

-- | Types whose values can be converted to an 'Int'.
class ToInt a where
  -- | Convert a value to an 'Int'.
  toInt :: a -> Int

-- | Converts a singleton to the number it denotes: first to a term-level
-- 'Nat' with 'snatToNat', then to 'Int' with 'fromIntegral' (which is
-- @fromInteger . toInteger@).
instance ToInt (SNat n) where
  toInt :: SNat n -> Int
  toInt = fromIntegral . snatToNat

---------------------------------------------------------
-- Addition
---------------------------------------------------------

-- | Infix notation for 'Plus', the addition of type-level naturals.
--
-- No fixity declaration appears in this module, so '+' gets the default
-- @infixl 9@. For example, @m + Z :~: m@ parses as @(m + Z) :~: m@.
type family (a :: Nat) + (b :: Nat) :: Nat where
  a + b = Plus a b

-- | Addition of singleton naturals; the value-level counterpart of '+'.
--
-- Recurses on the first argument, matching how 'Plus' reduces. It makes
-- one recursive call per successor in the first argument.
--
-- >>> reflect $ sPlus s3 s1
-- 4
sPlus :: forall n1 n2. SNat n1 -> SNat n2 -> SNat (n1 + n2)
sPlus SZ n = n
sPlus x@SS y =
  -- Add the predecessor of x to y, then take the successor. 'SS' needs an
  -- 'SNatI' dictionary for the inner sum, and 'withSNat' supplies it.
  withSNat (sPlus (prev x) y) SS

---------------------------------------------------------
-- View pattern access to the predecessor
---------------------------------------------------------

-- | Result type of the 'snat_' view. A singleton natural is either zero or
-- the successor of an explicitly available predecessor singleton.
data SNat_ (n :: Nat) where
  -- | Zero.
  SZ_ :: SNat_ 'Z
  -- | Successor, carrying the predecessor as an explicit 'SNat'.
  SS_ :: SNat n -> SNat_ ('S n)

-- | View function exposing the predecessor of a singleton natural.
--
-- Matching on 'SS' directly only brings an 'SNatI' constraint for the
-- predecessor into scope. 'snat_' turns that constraint into an explicit
-- 'SNat' value. Use it with @ViewPatterns@:
--
-- @
-- f :: forall p. SNat p -> ...
-- f SZ = ...
-- f (snat_ -> SS_ m) = ...
-- @
snat_ :: SNat n -> SNat_ n
snat_ SZ = SZ_
snat_ SS = SS_ snat

-- | Predecessor of a non-zero singleton natural.
--
-- Total: 'SZ' only has type @'SNat' 'Z'@, so 'SS' is the only possible
-- match here. The result is rebuilt with 'snat' from the 'SNatI'
-- dictionary that 'SS' carries.
prev :: SNat (S n) -> SNat n
prev SS = snat

-- | Successor of a singleton natural.
--
-- 'SS' requires an 'SNatI' constraint for the predecessor; 'withSNat'
-- derives that constraint from @x@.
next :: SNat n -> SNat (S n)
next x = withSNat x SS