{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -Wno-duplicate-exports #-}
module Data.Type.Nat.Singleton.Fast (
SNat (Z, S),
fromSNat,
fromSNatRaw,
plus,
decSNat,
SomeSNat (..),
withSomeSNat,
toSomeSNat,
toSomeSNatRaw,
fromSomeSNat,
fromSomeSNatRaw,
plusUnitL,
plusUnitR,
plusCommS,
plusComm,
plusAssoc,
KnownNat (..),
withKnownNat,
SNatRep,
intToSNatRep,
snatRepToInt,
SNat (UnsafeSNat, snatRep),
) where
import Control.DeepSeq (NFData (..))
import Control.Exception (assert)
import Data.Kind (Constraint, Type)
import Data.Maybe (isJust)
import Data.Proxy (Proxy (..))
import Data.Type.Equality (type (:~:) (Refl))
import Data.Type.Nat (Nat (..), Pos, Pred, type (+))
import GHC.TypeLits qualified as GHC
import Text.Printf (printf)
import Unsafe.Coerce (unsafeCoerce)
#ifdef SNAT_AS_WORD8
import Control.Exception (throw, ArithException (Overflow, Underflow))
import Data.Word (Word8)
#endif
-- | Runtime representation of an 'SNat', fixed at build time by a CPP
-- flag: one flag selects 'Word8', the other selects 'Int' (see the
-- conditionals below). Building with neither flag is a CPP error, except
-- when the file is processed by HLint.
#ifdef SNAT_AS_WORD8
type SNatRep = Word8
#elif defined(SNAT_AS_INT)
type SNatRep = Int
#elif !defined(__HLINT__)
#error "cpp: define one of [SNAT_AS_WORD8, SNAT_AS_INT]"
#endif
-- | Whether a raw value can stand for a natural number, i.e. whether it
-- is not negative. This is trivially true for an unsigned representation.
isValidSNatRep :: SNatRep -> Bool
isValidSNatRep rep = 0 <= rep
-- | Raw representation of zero.
mkZRep :: SNatRep
mkZRep = 0
{-# INLINE mkZRep #-}

-- | Raw representation of the successor: adds one. No overflow check is
-- performed.
mkSRep :: SNatRep -> SNatRep
mkSRep rep = rep + 1
{-# INLINE mkSRep #-}

-- | Raw representation of the predecessor: subtracts one. There is no
-- lower-bound check; 'elSNatRep' only applies it to non-zero values.
unSRep :: SNatRep -> SNatRep
unSRep rep = rep - 1
{-# INLINE unSRep #-}
-- | Case analysis on a raw representation.
--
-- Returns @onZero@ when the value equals 'mkZRep'. Otherwise it applies
-- @onSucc@ to the predecessor. Validity of the input is checked only
-- through 'assert', which GHC drops when optimising unless told to keep
-- assertions.
elSNatRep :: a -> (SNatRep -> a) -> SNatRep -> a
elSNatRep onZero onSucc rep = assert (isValidSNatRep rep) result
  where
    result
      | rep == mkZRep = onZero
      | otherwise = onSucc (unSRep rep)
{-# INLINE elSNatRep #-}
-- | Singleton for a type-level 'Nat', stored as a plain number.
--
-- 'UnsafeSNat' does not check that the stored value agrees with the
-- index @n@. The unsafe coercions elsewhere in this module rely on that
-- agreement, so prefer the 'Z' and 'S' patterns for construction.
type SNat :: Nat -> Type
newtype SNat n = UnsafeSNat
  { snatRep :: SNatRep
  -- ^ The numeric value of the singleton.
  }

-- A nominal role stops 'Data.Coerce.coerce' from changing the index,
-- which would produce a singleton whose value disagrees with its type.
type role SNat nominal
-- | The singleton for zero.
mkZ :: SNat Z
mkZ = UnsafeSNat mkZRep
{-# INLINE mkZ #-}

-- | The singleton for the successor of the given singleton.
mkS :: SNat n -> SNat (S n)
mkS (UnsafeSNat rep) = UnsafeSNat (mkSRep rep)
{-# INLINE mkS #-}
-- | One layer of a natural-number singleton: zero, or a successor of a
-- predecessor of type @self@. 'projectSNat' and 'embedSNat' convert
-- between this and 'SNat' to implement the 'Z' and 'S' patterns.
data SNatF (self :: Nat -> Type) (n :: Nat) where
  ZF :: SNatF self Z
  SF :: !(self n) -> SNatF self (S n)
-- | Expose the outermost layer of an 'SNat'.
--
-- 'elSNatRep' only sees the raw number, so GHC cannot refine the index
-- in either branch. 'unsafeCoerce' asserts the refinement instead, which
-- is sound only while every 'SNat' value agrees with its index.
projectSNat :: SNat n -> SNatF SNat n
projectSNat (UnsafeSNat rep) = elSNatRep onZero onSucc rep
  where
    onZero = unsafeCoerce ZF
    onSucc predRep = unsafeCoerce (SF (UnsafeSNat predRep))
{-# INLINE projectSNat #-}
-- | Rebuild an 'SNat' from one exposed layer; the inverse of 'projectSNat'.
embedSNat :: SNatF SNat n -> SNat n
embedSNat ZF = mkZ
embedSNat (SF predSNat) = mkS predSNat
{-# INLINE embedSNat #-}
-- | Matches or builds the zero singleton. A successful match brings
-- @n ~ Z@ into scope.
pattern Z :: () => (n ~ Z) => SNat n
pattern Z <- (projectSNat -> ZF)
  where
    Z = embedSNat ZF
{-# INLINE Z #-}
-- | Matches or builds a successor singleton. A successful match provides
-- @Pos n@ and binds the predecessor.
pattern S :: () => (Pos n) => SNat (Pred n) -> SNat n
pattern S predSNat <- (projectSNat -> SF predSNat)
  where
    S predSNat = embedSNat (SF predSNat)
{-# INLINE S #-}

-- Together, 'Z' and 'S' cover every 'SNat'.
{-# COMPLETE Z, S #-}
-- | Equality of numeric values, decided by 'decSNat'.
instance Eq (SNat n) where
  (==) :: SNat n -> SNat n -> Bool
  lhs == rhs = isJust (lhs `decSNat` rhs)
-- | Shows the unary form, for example @S (S Z)@.
instance Show (SNat n) where
  showsPrec :: Int -> SNat n -> ShowS
  showsPrec _ Z = showString "Z"
  showsPrec d (S predSNat) =
    showParen (d > 10) (showString "S " . showsPrec 11 predSNat)

-- Forcing an 'SNat' forces its underlying number.
deriving newtype instance NFData (SNat n)
-- | Convert a singleton to any integral type.
fromSNat :: (Integral i) => SNat n -> i
fromSNat (UnsafeSNat rep) = fromIntegral rep
-- | The raw representation of a singleton, unconverted.
fromSNatRaw :: SNat n -> SNatRep
fromSNatRaw (UnsafeSNat rep) = rep
-- | Add two singletons by adding their representations. No overflow
-- check is performed.
plus :: SNat n -> SNat m -> SNat (n + m)
plus (UnsafeSNat lhs) (UnsafeSNat rhs) = UnsafeSNat (lhs + rhs)
-- | Decide whether two singletons denote the same number.
--
-- The representations are compared. On a match, a proof that the indices
-- are equal is fabricated with 'unsafeCoerce', which is sound only while
-- every 'SNat' agrees with its index.
decSNat :: SNat n -> SNat m -> Maybe (n :~: m)
decSNat (UnsafeSNat lhs) (UnsafeSNat rhs)
  | lhs == rhs = Just (unsafeCoerce Refl)
  | otherwise = Nothing
-- | Convert an 'Int' to the raw representation.
--
-- With the 'Word8' representation, inputs outside @[0, 255]@ throw
-- 'Underflow' or 'Overflow'. With the 'Int' representation this is the
-- identity.
intToSNatRep :: Int -> SNatRep
#ifdef SNAT_AS_WORD8
intToSNatRep i
  | i < 0 = throw Underflow
  | i > fromIntegral (maxBound @Word8) = throw Overflow
  | otherwise = fromIntegral @Int @Word8 i
#else
intToSNatRep i = i
#endif
{-# INLINE intToSNatRep #-}
-- | Convert the raw representation to an 'Int'. Every representable
-- value fits, so this never fails.
snatRepToInt :: SNatRep -> Int
#ifdef SNAT_AS_WORD8
snatRepToInt rep = fromIntegral @Word8 @Int rep
#else
snatRepToInt rep = rep
#endif
{-# INLINE snatRepToInt #-}
-- | A natural-number singleton whose type index has been hidden. The
-- wrapped 'SNat' is stored strictly.
type SomeSNat :: Type
data SomeSNat where
  SomeSNat :: forall (n :: Nat). !(SNat n) -> SomeSNat
-- | Two hidden singletons are equal when 'decSNat' finds their values
-- equal, even though their (hidden) indices may differ syntactically.
instance Eq SomeSNat where
  (==) :: SomeSNat -> SomeSNat -> Bool
  SomeSNat lhs == SomeSNat rhs = isJust (lhs `decSNat` rhs)

deriving stock instance Show SomeSNat
-- | Fully evaluates the wrapped singleton.
instance NFData SomeSNat where
  rnf :: SomeSNat -> ()
  rnf = withSomeSNat rnf
-- | Eliminate a 'SomeSNat' with a continuation that works for any index.
withSomeSNat :: (forall n. SNat n -> a) -> SomeSNat -> a
withSomeSNat k some = case some of
  SomeSNat sn -> k sn
-- | Wrap an integral value as a singleton with a hidden index.
--
-- Calls 'error' when the value is negative or larger than the biggest
-- 'SNatRep'. The upper-bound check is needed because 'fromIntegral'
-- would otherwise wrap silently (for example 256 would become 0 with the
-- 'Word8' representation). That would yield a singleton whose value
-- differs from the input.
toSomeSNat :: (Integral i) => i -> SomeSNat
toSomeSNat r
  | value < 0 || value > toInteger (maxBound :: SNatRep) =
      error $ printf "cannot convert %d to natural number singleton" value
  | otherwise = SomeSNat (UnsafeSNat (fromInteger value))
  where
    -- Compare in 'Integer' so that the range check itself cannot wrap.
    value = toInteger r
-- | Wrap a raw representation value as a singleton with a hidden index.
--
-- Calls 'error' when the value is not a valid natural number, as judged
-- by 'isValidSNatRep'. This can only happen with a signed representation.
toSomeSNatRaw :: SNatRep -> SomeSNat
toSomeSNatRaw r
  | not (isValidSNatRep r) =
      -- The offending value must be passed to printf. Without it, the
      -- "%d" directive makes printf itself fail, and the intended message
      -- is never produced.
      error $ printf "cannot convert %d to natural number singleton" r
  | otherwise = SomeSNat (UnsafeSNat r)
-- | Convert a hidden-index singleton to any integral type.
fromSomeSNat :: (Integral i) => SomeSNat -> i
fromSomeSNat (SomeSNat sn) = fromSNat sn
-- | The raw representation inside a hidden-index singleton.
fromSomeSNatRaw :: SomeSNat -> SNatRep
fromSomeSNatRaw (SomeSNat sn) = fromSNatRaw sn
-- | Left identity of addition. GHC checks this one itself: the body is a
-- plain 'Refl'.
plusUnitL :: Proxy n -> Z + n :~: n
plusUnitL _ = Refl

-- | Right identity of addition, asserted with 'unsafeAxiom'.
plusUnitR :: SNat n -> n + Z :~: n
plusUnitR _ = unsafeAxiom

-- | Moving a successor from outside a sum onto its second operand,
-- asserted with 'unsafeAxiom'.
plusCommS :: SNat n -> Proxy m -> S (n + m) :~: n + S m
plusCommS _ _ = unsafeAxiom

-- | Commutativity of addition, asserted with 'unsafeAxiom'.
plusComm :: SNat n -> SNat m -> n + m :~: m + n
plusComm _ _ = unsafeAxiom

-- | Associativity of addition, asserted with 'unsafeAxiom'.
plusAssoc :: SNat n -> Proxy m -> Proxy l -> (n + m) + l :~: n + (m + l)
plusAssoc _ _ _ = unsafeAxiom

-- | Fabricate a proof that two type-level naturals are equal, for
-- equations GHC cannot prove on its own. This is only safe for laws that
-- hold for every natural, such as the arithmetic identities above.
unsafeAxiom :: forall (a :: Nat) (b :: Nat). a :~: b
unsafeAxiom = unsafeCoerce Refl
-- | The "GHC.TypeLits" natural corresponding to a unary 'Nat', obtained
-- by counting successors.
type FromNat :: Nat -> GHC.Nat
type family FromNat n where
  FromNat Z = 0
  FromNat (S k) = FromNat k GHC.+ 1
-- | Naturals whose singleton can be supplied implicitly.
type KnownNat :: Nat -> Constraint
class KnownNat n where
  -- | The singleton for @n@.
  natSing :: SNat n

-- | Zero is known directly.
instance KnownNat Z where
  natSing :: SNat Z
  natSing = Z

-- | A successor is known when its predecessor is.
instance (KnownNat n) => KnownNat (S n) where
  natSing :: SNat (S n)
  natSing = S natSing
-- | A reified constraint: matching on 'Dict' brings @c@ into scope.
data Dict (c :: Constraint) where
  Dict :: (c) => Dict c

-- | Stand-in with the same runtime shape as @Dict (KnownNat n)@.
--
-- This relies on GHC representing the dictionary of the single-method
-- class 'KnownNat' as the 'SNat' itself. A one-field constructor holding
-- an 'SNat' is then laid out like 'Dict' holding a 'KnownNat' dictionary.
-- It must remain a @data@ declaration: a @newtype@ would erase the
-- constructor and break the correspondence.
data FakeKnownNat n = FakeKnownNat (SNat n)
{-# ANN FakeKnownNat ("HLint: ignore Use newtype instead of data" :: String) #-}

-- | Run a computation that needs @KnownNat n@, using the given singleton
-- as the instance.
withKnownNat :: SNat n -> ((KnownNat n) => r) -> r
withKnownNat sn k =
  case reify sn of
    Dict -> k
  where
    -- Reinterpret the fake wrapper as a genuine dictionary (see
    -- 'FakeKnownNat' for why the layouts agree).
    reify :: SNat m -> Dict (KnownNat m)
    reify sm = unsafeCoerce (FakeKnownNat sm)