{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -Wno-duplicate-exports #-}

module Data.Type.Nat.Singleton.Fast (
  -- * Natural Number Singletons
  SNat (Z, S),
  fromSNat,
  fromSNatRaw,
  plus,
  decSNat,

  -- * Existential Wrapper
  SomeSNat (..),
  withSomeSNat,
  toSomeSNat,
  toSomeSNatRaw,
  fromSomeSNat,
  fromSomeSNatRaw,

  -- * Laws
  plusUnitL,
  plusUnitR,
  plusCommS,
  plusComm,
  plusAssoc,

  -- * Linking Type-Level and Value-Level
  KnownNat (..),
  withKnownNat,

  -- * Fast
  SNatRep,
  intToSNatRep,
  snatRepToInt,
  SNat (UnsafeSNat, snatRep),
) where

import Control.DeepSeq (NFData (..))
import Control.Exception (assert)
import Data.Kind (Constraint, Type)
import Data.Maybe (isJust)
import Data.Proxy (Proxy (..))
import Data.Type.Equality (type (:~:) (Refl))
import Data.Type.Nat (Nat (..), Pos, Pred, type (+))
import GHC.TypeLits qualified as GHC
import Text.Printf (printf)
import Unsafe.Coerce (unsafeCoerce)

#ifdef SNAT_AS_WORD8
import Control.Exception (throw, ArithException (Overflow, Underflow))
import Data.Word (Word8)
#endif

{- $setup
>>> import Data.Type.Nat.Singleton.Fast.Arbitrary
-}

--------------------------------------------------------------------------------
-- Natural Number Singleton Representation
--------------------------------------------------------------------------------

-- Runtime representation backing every singleton, chosen at build time by
-- CPP: the 8-bit flag is checked first, then the Int flag. If neither is
-- defined, compilation stops with #error, except under HLint (which defines
-- the HLint macro, presumably so the file can be linted without a flag).
#if defined(SNAT_AS_WORD8)
-- | Runtime representation of 'SNat' (unsigned 8-bit build).
type SNatRep = Word8
#elif defined(SNAT_AS_INT)
-- | Runtime representation of 'SNat' (machine 'Int' build).
type SNatRep = Int
#elif !defined(__HLINT__)
#error "cpp: define one of [SNAT_AS_WORD8, SNAT_AS_INT]"
#endif

-- | Whether a raw representation denotes a natural number, i.e. is
-- non-negative. For an unsigned representation this is always 'True'.
isValidSNatRep :: SNatRep -> Bool
isValidSNatRep rep = rep >= 0

-- | The representation of zero.
mkZRep :: SNatRep
mkZRep = 0
{-# INLINE mkZRep #-}

-- | The representation of the successor.
--
-- The 'assert' is checked only when assertions are enabled. It catches
-- silent wrap-around at 'maxBound', which is easy to reach with a narrow
-- representation such as 'Word8'. Without it, the result would be a
-- singleton whose value disagrees with its type index.
mkSRep :: SNatRep -> SNatRep
mkSRep rep = assert (rep < maxBound) (rep + 1)
{-# INLINE mkSRep #-}

-- | The representation of the predecessor. This is only meaningful for a
-- non-zero argument; 'elSNatRep' calls it only after ruling out zero.
unSRep :: SNatRep -> SNatRep
unSRep rep = rep - 1
{-# INLINE unSRep #-}

-- | Eliminator for raw representations.
--
-- Returns @ifZ@ when the representation is zero. Otherwise it applies @ifS@
-- to the predecessor representation. The validity check is an 'assert', so
-- it only runs when assertions are enabled.
elSNatRep :: a -> (SNatRep -> a) -> SNatRep -> a
elSNatRep ifZ ifS rep = assert (isValidSNatRep rep) result
  where
    result
      | rep == mkZRep = ifZ
      | otherwise = ifS (unSRep rep)
{-# INLINE elSNatRep #-}

--------------------------------------------------------------------------------
-- Natural Number Singletons
--------------------------------------------------------------------------------

-- | @'SNat' n@ is the singleton type for natural numbers.
--
-- At runtime an @'SNat' n@ is a single 'SNatRep' holding the numeric value
-- of @n@. The 'UnsafeSNat' constructor does not check that the stored value
-- matches the index, so prefer the 'Z' and 'S' patterns.
type SNat :: Nat -> Type
newtype SNat n = UnsafeSNat
  { snatRep :: SNatRep
  -- ^ The numeric value of the type index @n@.
  }

-- The index must be nominal. Every @n@ shares the same representation, so a
-- representational role would let 'Data.Coerce.coerce' change the index
-- without changing the stored value.
type role SNat nominal

-- | The zero singleton, built directly from the zero representation.
mkZ :: SNat Z
mkZ = UnsafeSNat mkZRep
{-# INLINE mkZ #-}

-- | The successor singleton. It wraps the successor of the underlying
-- representation.
mkS :: SNat n -> SNat (S n)
mkS (UnsafeSNat rep) = UnsafeSNat (mkSRep rep)
{-# INLINE mkS #-}

-- | @'SNatF'@ is the base functor of @'SNat'@.
--
-- One layer of a natural number: either zero or the successor of some
-- @snat n@. 'projectSNat' and 'embedSNat' convert between this layer and
-- 'SNat', and they implement the 'Z' and 'S' pattern synonyms.
data SNatF (snat :: Nat -> Type) (n :: Nat) where
  -- | The zero layer.
  ZF :: SNatF snat Z
  -- | The successor layer. The predecessor field is strict.
  SF :: !(snat n) -> SNatF snat (S n)

-- | Expose the outermost layer of an 'SNat'.
--
-- 'elSNatRep' knows nothing about the type index. Each branch therefore
-- builds its layer at a convenient type and passes it through
-- 'unsafeCoerce' to reach @'SNatF' 'SNat' n@. This is sound only as long as
-- every 'SNat' carries the representation that matches its index.
projectSNat :: SNat n -> SNatF SNat n
projectSNat snat =
  elSNatRep
    (unsafeCoerce (ZF :: SNatF SNat Z))
    (\predRep -> unsafeCoerce (SF (UnsafeSNat predRep)))
    snat.snatRep
{-# INLINE projectSNat #-}

-- | Rebuild an 'SNat' from one layer. This is the inverse of 'projectSNat'.
embedSNat :: SNatF SNat n -> SNat n
embedSNat ZF = mkZ
embedSNat (SF predecessor) = mkS predecessor
{-# INLINE embedSNat #-}

-- | Matches, and builds, the zero singleton. A successful match brings
-- @n ~ Z@ into scope.
pattern Z :: () => (n ~ Z) => SNat n
pattern Z <- (projectSNat -> ZF)
  where
    Z = embedSNat ZF
{-# INLINE Z #-}

-- | Matches, and builds, a successor singleton and exposes its predecessor.
-- A successful match brings @Pos n@ into scope.
pattern S :: () => (Pos n) => SNat (Pred n) -> SNat n
pattern S predecessor <- (projectSNat -> SF predecessor)
  where
    S predecessor = embedSNat (SF predecessor)
{-# INLINE S #-}

{-# COMPLETE Z, S #-}

-- | Compares the runtime representations via 'decSNat'.
instance Eq (SNat n) where
  (==) :: SNat n -> SNat n -> Bool
  lhs == rhs = case decSNat lhs rhs of
    Just _ -> True
    Nothing -> False

-- | Renders a singleton in constructor form, e.g. @S (S Z)@.
instance Show (SNat n) where
  showsPrec :: Int -> SNat n -> ShowS
  showsPrec _ Z = showString "Z"
  showsPrec d (S predecessor) =
    showParen (d > 10) (showString "S " . showsPrec 11 predecessor)

-- Reuses the 'NFData' instance of the underlying 'SNatRep'.
deriving newtype instance NFData (SNat n)

-- | @'fromSNat' n@ returns the numeric value of @n@ in any 'Integral' type.
--
-- The conversion uses 'fromIntegral', so it may wrap for a fixed-width
-- target type that is too narrow.
fromSNat :: (Integral i) => SNat n -> i
fromSNat snat = fromIntegral snat.snatRep

-- | @'fromSNatRaw' n@ returns the raw underlying 'SNatRep' of @n@.
fromSNatRaw :: SNat n -> SNatRep
fromSNatRaw (UnsafeSNat rep) = rep

-- | Addition for natural number singletons.
--
-- Adds the underlying representations. The 'assert' is checked only when
-- assertions are enabled. It flags sums that would exceed 'maxBound' and
-- wrap around, which would otherwise produce a singleton whose value
-- disagrees with @n + m@.
plus :: SNat n -> SNat m -> SNat (n + m)
n `plus` m =
  assert (n.snatRep <= maxBound - m.snatRep) $
    UnsafeSNat (n.snatRep + m.snatRep)

-- | Decidable equality for natural number singletons.
--
-- Two singletons are equal exactly when their representations agree. In
-- that case the type-level equality is produced with 'unsafeCoerce', which
-- is sound only while every 'SNat' carries the representation matching its
-- index.
decSNat :: SNat n -> SNat m -> Maybe (n :~: m)
decSNat (UnsafeSNat a) (UnsafeSNat b)
  | a /= b = Nothing
  | otherwise = Just (unsafeCoerce (Refl :: () :~: ()))

-- | Convert an 'Int' to an 'SNatRep'.
intToSNatRep :: Int -> SNatRep
#ifdef SNAT_AS_WORD8
-- Narrowing build: values outside the 'Word8' range raise an arithmetic
-- exception instead of wrapping.
-- TODO: Make this safe.
intToSNatRep i
  | i < 0 = throw Underflow
  | i <= upper = fromIntegral @Int @Word8 i
  | otherwise = throw Overflow
  where
    upper = fromIntegral (maxBound @Word8)
{-# INLINE intToSNatRep #-}
#else
-- Int build: the representation is the 'Int' itself.
intToSNatRep i = i
{-# INLINE intToSNatRep #-}
#endif

-- | Convert an 'SNatRep' to an 'Int'.
snatRepToInt :: SNatRep -> Int
#ifdef SNAT_AS_WORD8
-- Widening conversion: every 'Word8' fits in an 'Int'.
snatRepToInt w = fromIntegral @Word8 @Int w
{-# INLINE snatRepToInt #-}
#else
-- Int build: the representation already is an 'Int'.
snatRepToInt n = n
{-# INLINE snatRepToInt #-}
#endif

--------------------------------------------------------------------------------
-- Existential Wrapper
--------------------------------------------------------------------------------

-- | An existential wrapper around natural number singletons.
--
-- The type-level index is hidden. The wrapped 'SNat' sits in a strict field,
-- so it is evaluated whenever the 'SomeSNat' is. Recover it with
-- 'withSomeSNat' or by pattern matching.
type SomeSNat :: Type
data SomeSNat = forall (n :: Nat). SomeSNat !(SNat n)

-- | Two wrapped singletons are equal when 'decSNat' can relate their indices.
instance Eq SomeSNat where
  (==) :: SomeSNat -> SomeSNat -> Bool
  SomeSNat lhs == SomeSNat rhs = maybe False (const True) (decSNat lhs rhs)

-- Standalone-derived: shows the constructor followed by the wrapped
-- singleton, e.g. @SomeSNat (S Z)@.
deriving instance Show SomeSNat

-- | Fully evaluates the wrapped singleton.
instance NFData SomeSNat where
  rnf :: SomeSNat -> ()
  rnf = withSomeSNat rnf

-- | Evaluate a term with access to the underlying @'SNat'@.
--
-- The continuation must be polymorphic in the index, because the index is
-- not known statically.
withSomeSNat :: (forall n. SNat n -> a) -> SomeSNat -> a
withSomeSNat k = \case
  SomeSNat snat -> k snat

{-| @'toSomeSNat' n@ constructs the singleton @'SNat' n@.

Calls 'error' if @n@ is negative or larger than the biggest value that
'SNatRep' can hold. Out-of-range inputs are rejected rather than wrapped
by a narrowing conversion, which would otherwise yield a singleton for the
wrong number.

prop> toSomeSNat (fromSomeSNat n) == n
-}
toSomeSNat :: (Integral i) => i -> SomeSNat
toSomeSNat r
  | value < 0 || value > toInteger (maxBound :: SNatRep) =
      error $ printf "cannot convert %d to natural number singleton" value
  | otherwise = SomeSNat (UnsafeSNat (fromInteger value))
  where
    -- Range-check as an 'Integer' so that neither the source type nor
    -- 'SNatRep' can overflow during the comparison.
    value = toInteger r

{-| @'toSomeSNatRaw' r@ wraps the raw representation @r@ as a singleton.

Calls 'error' if @r@ is not a valid representation, that is, if it is
negative.

prop> toSomeSNatRaw (fromSomeSNatRaw n) == n
-}
toSomeSNatRaw :: SNatRep -> SomeSNat
toSomeSNatRaw r
  | not (isValidSNatRep r) =
      -- The offending value must be supplied for the %d directive.
      -- Without it, printf itself fails at runtime and hides this message.
      error $ printf "cannot convert %d to natural number singleton" (toInteger r)
  | otherwise = SomeSNat (UnsafeSNat r)

-- | @'fromSomeSNat' n@ returns the numeric representation of the wrapped singleton.
fromSomeSNat :: (Integral i) => SomeSNat -> i
fromSomeSNat (SomeSNat snat) = fromSNat snat

-- | @'fromSomeSNatRaw' n@ returns the raw underlying 'SNatRep' of the
-- wrapped singleton.
fromSomeSNatRaw :: SomeSNat -> SNatRep
fromSomeSNatRaw = withSomeSNat fromSNatRaw

--------------------------------------------------------------------------------
-- Laws
--------------------------------------------------------------------------------

-- | Left identity of addition. This holds by computation of '+', so 'Refl'
-- suffices.
plusUnitL :: Proxy n -> Z + n :~: n
plusUnitL _ = Refl

-- | Right identity of addition (postulated).
plusUnitR :: SNat n -> n + Z :~: n
plusUnitR _ = unsafeAxiom

-- | Moving a successor across addition (postulated).
plusCommS :: SNat n -> Proxy m -> S (n + m) :~: n + S m
plusCommS _ _ = unsafeAxiom

-- | Commutativity of addition (postulated).
plusComm :: SNat n -> SNat m -> n + m :~: m + n
plusComm _ _ = unsafeAxiom

-- | Associativity of addition (postulated).
plusAssoc :: SNat n -> Proxy m -> Proxy l -> (n + m) + l :~: n + (m + l)
plusAssoc _ _ _ = unsafeAxiom

-- | An equality proof fabricated with 'unsafeCoerce', without any evidence.
-- It is used only for the arithmetic laws above, which are true of
-- natural numbers but are not proven here.
unsafeAxiom :: forall (a :: Nat) (b :: Nat). a :~: b
unsafeAxiom = unsafeCoerce (Refl :: () :~: ())

--------------------------------------------------------------------------------
-- Linking Type-Level and Value-Level
--------------------------------------------------------------------------------

-- | Translate a unary 'Nat' into the corresponding GHC type-level literal.
--
-- NOTE(review): not referenced anywhere else in the code shown here; confirm
-- whether it is still needed.
type FromNat :: Nat -> GHC.Nat
type family FromNat n where
  FromNat Z = 0
  FromNat (S n) = FromNat n GHC.+ 1

-- | Naturals whose singleton can be produced implicitly.
--
-- NOTE(review): 'withKnownNat' builds this dictionary with 'unsafeCoerce'.
-- That appears to rely on the class having exactly one method and no
-- superclasses, so keep it that way.
type KnownNat :: Nat -> Constraint
class KnownNat n where
  -- | The singleton for @n@.
  natSing :: SNat n

-- | Zero is always known.
instance KnownNat Z where
  natSing :: SNat Z
  natSing = mkZ

-- | A successor is known whenever its predecessor is.
instance (KnownNat n) => KnownNat (S n) where
  natSing :: SNat (S n)
  natSing = mkS natSing

-- | A reified constraint: pattern matching on 'Dict' brings @c@ into scope.
data Dict (c :: Constraint) :: Type where
  Dict :: (c) => Dict c

-- | A stand-in with the same runtime shape as @'Dict' ('KnownNat' n)@: one
-- boxed constructor holding an @'SNat' n@. This relies on GHC representing
-- the dictionary of a single-method, superclass-free class as that method.
-- It must stay a @data@ declaration (not a @newtype@) so that its layout
-- matches 'Dict', hence the HLint suppression below.
data FakeKnownNat n = FakeKnownNat (SNat n)
{-# ANN FakeKnownNat ("HLint: ignore Use newtype instead of data" :: String) #-}

-- | Run a computation that needs @'KnownNat' n@, using an explicit
-- singleton as the evidence.
withKnownNat :: SNat n -> ((KnownNat n) => r) -> r
withKnownNat snat action =
  case fabricateDict snat of
    Dict -> action
  where
    -- Reinterpret the stand-in box as a genuine dictionary box.
    fabricateDict :: SNat m -> Dict (KnownNat m)
    fabricateDict = unsafeCoerce . FakeKnownNat