{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE Trustworthy #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module Clash.Promoted.Nat
(
SNat (..)
, snatProxy
, withSNat
, snatToInteger, snatToNatural, snatToNum
, natToInteger, natToNatural, natToNum
, addSNat, mulSNat, powSNat, minSNat, maxSNat, succSNat
, subSNat, divSNat, modSNat, flogBaseSNat, clogBaseSNat, logBaseSNat, predSNat
, pow2SNat
, SNatLE (..), compareSNat
, UNat (..)
, toUNat
, fromUNat
, addUNat, mulUNat, powUNat
, predUNat, subUNat
, BNat (..)
, toBNat
, fromBNat
, showBNat
, succBNat, addBNat, mulBNat, powBNat
, predBNat, div2BNat, div2Sub1BNat, log2BNat
, stripZeros
, leToPlus
, leToPlusKN
)
where
import Data.Constraint (Dict(..), (:-)(..))
import Data.Constraint.Nat (euclideanNat)
import Data.Kind (Type)
import Data.Type.Equality ((:~:)(..))
import Data.Type.Ord (OrderingI(..))
import GHC.Show (appPrec)
import GHC.TypeLits (KnownNat, Nat, type (+), type (-), type (*),
type (^), type (<=),
cmpNat, sameNat,
natVal)
import GHC.TypeLits.Extra (CLog, FLog, Div, Log, Mod, Min, Max)
import GHC.Natural (naturalFromInteger)
import Language.Haskell.TH (appT, conT, litT, numTyLit, sigE)
import Language.Haskell.TH.Syntax (Lift (..))
import Language.Haskell.TH.Compat
import Numeric.Natural (Natural)
import Clash.Annotations.Primitive (hasBlackBox)
import Clash.XException (ShowX (..), showsPrecXWith)
-- | Singleton value for a type-level natural number @n@.
--
-- The only constructor packs the 'KnownNat' evidence for @n@, so pattern
-- matching on it brings that instance into scope.
data SNat :: Nat -> Type where
  SNat :: KnownNat n => SNat n
-- | Lifts an 'SNat' to a Template Haskell expression: the 'SNat' constructor
-- annotated with the type @SNat \<literal\>@, where the literal is the value
-- carried by the singleton.
instance Lift (SNat n) where
  lift s = sigE [| SNat |] annotation
   where
    -- The concrete type used to pin down the spliced constructor.
    annotation = conT ''SNat `appT` litT (numTyLit (snatToInteger s))
  liftTyped = liftTypedFromUntyped
-- | Build an 'SNat' from any proxy-like value whose type mentions @n@.
-- The proxy argument itself is never inspected.
snatProxy :: forall n proxy . KnownNat n => proxy n -> SNat n
snatProxy _ = SNat @n
-- | Values up to and including 1024 are rendered as @d\<n\>@; larger values
-- are rendered as @SNat \@\<n\>@, parenthesised when the surrounding
-- precedence is above function application.
instance Show (SNat n) where
  showsPrec prec s@SNat
    | val <= 1024 = showChar 'd' . shows val
    | otherwise   = showParen (prec > appPrec) (showString "SNat @" . shows val)
   where
    val = snatToInteger s
-- | 'ShowX' rendering defined in terms of the regular 'Show' instance via
-- 'showsPrecXWith'.
instance ShowX (SNat n) where
  showsPrecX prec = showsPrecXWith showsPrec prec
-- | Run a function that expects an 'SNat' argument, creating that argument
-- from the 'KnownNat' instance in scope.
withSNat :: forall n a . KnownNat n => (SNat n -> a) -> a
withSNat k = k (SNat @n)
{-# INLINE withSNat #-}
-- | The 'Integer' value of the type-level natural @n@; @n@ is meant to be
-- supplied with a type application.
natToInteger :: forall n . KnownNat n => Integer
natToInteger = withSNat @n snatToInteger
{-# INLINE natToInteger #-}
-- | Reify an 'SNat' as an 'Integer' by matching on the constructor (which
-- provides the 'KnownNat' instance) and calling 'natVal'.
snatToInteger :: SNat n -> Integer
snatToInteger s = case s of
  SNat -> natVal s
{-# INLINE snatToInteger #-}
-- | The 'Natural' value of the type-level natural @n@; @n@ is meant to be
-- supplied with a type application.
natToNatural :: forall n . KnownNat n => Natural
natToNatural = withSNat @n snatToNatural
{-# INLINE natToNatural #-}
-- | Reify an 'SNat' as a 'Natural', going through 'snatToInteger'.
snatToNatural :: SNat n -> Natural
snatToNatural s = naturalFromInteger (snatToInteger s)
{-# INLINE snatToNatural #-}
-- | The value of the type-level natural @n@ in any 'Num' type; @n@ is meant
-- to be supplied with a type application.
natToNum :: forall n a . (Num a, KnownNat n) => a
natToNum = withSNat @n snatToNum
{-# INLINE natToNum #-}
-- | Reify an 'SNat' in any 'Num' type by converting its 'Integer' value with
-- 'fromInteger'.
snatToNum :: forall a n . Num a => SNat n -> a
snatToNum s = case s of
  SNat -> fromInteger (natVal s)
{-# INLINE snatToNum #-}
-- | Unary (Peano-style) representation of a type-level natural: zero, or the
-- successor of another 'UNat'.
data UNat :: Nat -> Type where
  -- | The number zero.
  UZero :: UNat 0
  -- | One more than the wrapped number.
  USucc :: forall n . UNat n -> UNat (n + 1)
-- | Rendered as the letter @u@ followed by the decimal value of @n@.
instance KnownNat n => Show (UNat n) where
  show u = 'u' : shows (natVal u) ""
-- | 'ShowX' rendering defined in terms of the regular 'Show' instance via
-- 'showsPrecXWith'.
instance KnownNat n => ShowX (UNat n) where
  showsPrecX prec = showsPrecXWith showsPrec prec
-- | Convert a singleton natural to its unary representation.
--
-- The number is compared against 1: anything larger recurses on its
-- predecessor, exactly 1 becomes @USucc UZero@, and anything smaller must be
-- 0 (the remaining case is unreachable and raises an error).
toUNat :: forall n . SNat n -> UNat n
toUNat s@SNat =
  case cmpNat (SNat @1) s of
    GTI
      | Just Refl <- sameNat s (SNat @0) -> UZero
      | otherwise ->
          error "toUNat: impossible: 1 > n and n /= 0 for (n :: Nat)"
    EQI -> USucc UZero
    LTI -> USucc (toUNat @(n - 1) (predSNat s))
-- | Convert a unary natural back to a singleton natural by adding 1 for each
-- 'USucc' layer on top of the singleton for 0.
fromUNat :: UNat n -> SNat n
fromUNat u = case u of
  UZero      -> SNat @0
  USucc prev -> addSNat (fromUNat prev) (SNat @1)
-- | Add two unary naturals.
--
-- If either operand is zero the other one is returned unchanged; otherwise
-- one 'USucc' is peeled off the left operand per recursive step.
addUNat :: UNat n -> UNat m -> UNat (n + m)
addUNat lhs rhs = case lhs of
  UZero      -> rhs
  USucc lhs' -> case rhs of
    UZero -> lhs
    _     -> USucc (addUNat lhs' rhs)
-- | Multiply two unary naturals.
--
-- A zero on either side yields zero; otherwise the right operand is added
-- once for every 'USucc' layer of the left operand.
mulUNat :: UNat n -> UNat m -> UNat (n * m)
mulUNat lhs rhs = case lhs of
  UZero      -> UZero
  USucc lhs' -> case rhs of
    UZero -> UZero
    _     -> addUNat rhs (mulUNat lhs' rhs)
-- | Raise a unary natural to the power of another.
--
-- An exponent of zero yields one; otherwise the base is multiplied in once
-- for every 'USucc' layer of the exponent.
powUNat :: UNat n -> UNat m -> UNat (n ^ m)
powUNat base ex = case ex of
  UZero     -> USucc UZero
  USucc ex' -> mulUNat base (powUNat base ex')
-- | Predecessor of a unary natural that is known to be at least 1.
--
-- The 'UZero' alternative contradicts the argument type and only raises an
-- error.
predUNat :: UNat (n+1) -> UNat n
predUNat u = case u of
  USucc prev -> prev
  UZero      ->
    error "predUNat: impossible: 0 minus 1, -1 is not a natural number"
-- | Subtract the second unary natural from the first.
--
-- The subtrahend is inspected first: when it is zero the minuend is returned
-- as is, otherwise one 'USucc' is removed from both sides. A zero minuend
-- with a non-zero subtrahend contradicts the types and raises an error.
subUNat :: UNat (m+n) -> UNat n -> UNat m
subUNat total amount = case amount of
  UZero         -> total
  USucc amount' -> case total of
    USucc total' -> subUNat total' amount'
    UZero        -> error "subUNat: impossible: 0 + (n + 1) ~ 0"
-- | Predecessor of a singleton natural that is known to be at least 1.
predSNat :: forall a . SNat (a + 1) -> SNat a
predSNat SNat = SNat @a
{-# INLINE predSNat #-}
-- | Successor of a singleton natural.
succSNat :: forall a . SNat a -> SNat (a + 1)
succSNat SNat = SNat @(a + 1)
{-# INLINE succSNat #-}
infixl 6 `addSNat`

-- | Add two singleton naturals.
addSNat :: forall a b . SNat a -> SNat b -> SNat (a + b)
addSNat SNat SNat = SNat @(a + b)
{-# INLINE addSNat #-}
infixl 6 `subSNat`

-- | Subtract the second singleton natural from the first; the type of the
-- first argument guarantees it is at least as large as the second.
subSNat :: forall a b . SNat (a + b) -> SNat b -> SNat a
subSNat SNat SNat = SNat @a
{-# INLINE subSNat #-}
infixl 7 `mulSNat`

-- | Multiply two singleton naturals.
mulSNat :: forall a b . SNat a -> SNat b -> SNat (a * b)
mulSNat SNat SNat = SNat @(a * b)
{-# INLINE mulSNat #-}
infixr 8 `powSNat`

-- | Raise the first singleton natural to the power of the second.
--
-- The function is kept 'OPAQUE' and carries a 'hasBlackBox' annotation.
powSNat :: forall a b . SNat a -> SNat b -> SNat (a ^ b)
powSNat SNat SNat = SNat @(a ^ b)
{-# OPAQUE powSNat #-}
{-# ANN powSNat hasBlackBox #-}
infixl 7 `divSNat`

-- | Divide the first singleton natural by the second, which must be at
-- least 1; the result is described by the type-level 'Div'.
divSNat :: forall b a . (1 <= b) => SNat a -> SNat b -> SNat (Div a b)
divSNat SNat SNat = SNat @(Div a b)
{-# INLINE divSNat #-}
infixl 7 `modSNat`

-- | Remainder of dividing the first singleton natural by the second, which
-- must be at least 1; the result is described by the type-level 'Mod'.
modSNat :: forall b a . (1 <= b) => SNat a -> SNat b -> SNat (Mod a b)
modSNat SNat SNat = SNat @(Mod a b)
{-# INLINE modSNat #-}
-- | The smaller of two singleton naturals, as described by the type-level
-- 'Min'.
minSNat :: forall a b . SNat a -> SNat b -> SNat (Min a b)
minSNat SNat SNat = SNat @(Min a b)
-- | The larger of two singleton naturals, as described by the type-level
-- 'Max'.
maxSNat :: forall a b . SNat a -> SNat b -> SNat (Max a b)
maxSNat SNat SNat = SNat @(Max a b)
-- | Singleton for @FLog base x@. The base must be at least 2 and the
-- argument at least 1.
--
-- The function is kept 'OPAQUE' and carries a 'hasBlackBox' annotation.
flogBaseSNat
  :: forall base x
   . (2 <= base, 1 <= x)
  => SNat base
  -- ^ Base
  -> SNat x
  -- ^ Argument of the logarithm
  -> SNat (FLog base x)
flogBaseSNat SNat SNat = SNat @(FLog base x)
{-# OPAQUE flogBaseSNat #-}
{-# ANN flogBaseSNat hasBlackBox #-}
-- | Singleton for @CLog base x@. The base must be at least 2 and the
-- argument at least 1.
--
-- The function is kept 'OPAQUE' and carries a 'hasBlackBox' annotation.
clogBaseSNat
  :: forall base x
   . (2 <= base, 1 <= x)
  => SNat base
  -- ^ Base
  -> SNat x
  -- ^ Argument of the logarithm
  -> SNat (CLog base x)
clogBaseSNat SNat SNat = SNat @(CLog base x)
{-# OPAQUE clogBaseSNat #-}
{-# ANN clogBaseSNat hasBlackBox #-}
-- | Singleton for @Log base x@, available only when @FLog base x@ and
-- @CLog base x@ coincide.
--
-- The function is kept 'OPAQUE' and carries a 'hasBlackBox' annotation.
logBaseSNat
  :: forall base x
   . (FLog base x ~ CLog base x)
  => SNat base
  -- ^ Base
  -> SNat x
  -- ^ Argument of the logarithm
  -> SNat (Log base x)
logBaseSNat SNat SNat = SNat @(Log base x)
{-# OPAQUE logBaseSNat #-}
{-# ANN logBaseSNat hasBlackBox #-}
-- | Singleton for 2 raised to the power of the given singleton natural.
pow2SNat :: forall a . SNat a -> SNat (2 ^ a)
pow2SNat SNat = SNat @(2 ^ a)
{-# INLINE pow2SNat #-}
-- | Outcome of ordering two type-level naturals. Each constructor carries
-- the matching inequality as a constraint, so pattern matching brings that
-- inequality into scope.
data SNatLE a b where
  -- | Evidence that @a@ is at most @b@.
  SNatLE :: forall a b . a <= b => SNatLE a b
  -- | Evidence that @a@ is strictly greater than @b@.
  SNatGT :: forall a b . b + 1 <= a => SNatLE a b

deriving instance Show (SNatLE a b)
-- | Compare two singleton naturals and return the matching 'SNatLE'
-- evidence.
--
-- When @a@ is greater than @b@, a second comparison of @b + 1@ against @a@
-- supplies the inequality that 'SNatGT' requires; the remaining outcome of
-- that second comparison is unreachable and raises an error.
compareSNat :: forall a b . SNat a -> SNat b -> SNatLE a b
compareSNat lhs@SNat rhs@SNat =
  case cmpNat lhs rhs of
    GTI ->
      case cmpNat (succSNat rhs) lhs of
        GTI -> error "compareSNat: impossible: a > b and b + 1 > a"
        EQI -> SNatGT
        LTI -> SNatGT
    EQI -> SNatLE
    LTI -> SNatLE
-- | Binary representation of a type-level natural. Each constructor wraps
-- the remaining, more significant part of the number.
data BNat :: Nat -> Type where
  -- | Terminator; stands for 0.
  BT :: BNat 0
  -- | Twice the wrapped number.
  B0 :: forall n . BNat n -> BNat (2 * n)
  -- | Twice the wrapped number, plus one.
  B1 :: forall n . BNat n -> BNat ((2 * n) + 1)
-- | Rendered as the letter @b@ followed by the decimal value of @n@.
instance KnownNat n => Show (BNat n) where
  show b = 'b' : shows (natVal b) ""
-- | 'ShowX' rendering defined in terms of the regular 'Show' instance via
-- 'showsPrecXWith'.
instance KnownNat n => ShowX (BNat n) where
  showsPrecX prec = showsPrecXWith showsPrec prec
-- | Render a 'BNat' as a string of binary digits prefixed with @0b@.
--
-- The constructors are walked from the outermost inwards and every digit is
-- prepended to the accumulator, so the outermost constructor ends up as the
-- right-most character of the result.
showBNat :: BNat n -> String
showBNat = go []
  where
    -- The accumulator holds the digits seen so far, right-most digit last.
    go :: String -> BNat m -> String
    go xs BT  = "0b" ++ xs
    go xs (B0 x) = go ('0':xs) x
    go xs (B1 x) = go ('1':xs) x
-- | Convert a singleton natural to its binary representation.
--
-- The number is compared against 1. Anything larger is split into
-- @Div n 2@ and @Mod n 2@ (using the 'euclideanNat' evidence): a remainder
-- of 0 yields 'B0', a remainder of 1 yields 'B1', both wrapping the
-- conversion of the quotient. Exactly 1 becomes @B1 BT@ and 0 becomes 'BT'.
-- The remaining alternatives are unreachable and raise an error.
toBNat :: forall n. SNat n -> BNat n
toBNat sn@SNat =
  case cmpNat (SNat @1) sn of
    LTI ->
      case euclideanNat @2 @n of
        Sub Dict ->
          case sameNat (SNat @(Mod n 2)) (SNat @0) of
            Just Refl -> B0 (toBNat (SNat @(Div n 2)))
            Nothing ->
              case sameNat (SNat @(Mod n 2)) (SNat @1) of
                Just Refl -> B1 (toBNat (SNat @(Div n 2)))
                Nothing ->
                  error "toBNat: impossible: n mod 2 is either 0 or 1"
    EQI -> B1 BT
    GTI ->
      case sameNat sn (SNat @0) of
        Just Refl -> BT
        _ -> error "toBNat: impossible: 1 > n and n /= 0 for (n :: Nat)"
-- | Convert a binary natural back to a singleton natural: 'BT' is 0, 'B0'
-- doubles the wrapped number, and 'B1' doubles it and adds 1.
fromBNat :: BNat n -> SNat n
fromBNat b = case b of
  BT      -> SNat @0
  B0 rest -> mulSNat (SNat @2) (fromBNat rest)
  B1 rest -> addSNat (mulSNat (SNat @2) (fromBNat rest)) (SNat @1)
-- | Add two binary naturals digit by digit.
--
-- Two 'B1' digits produce a 'B0' digit and carry one into the remaining
-- sum via 'succBNat'. When either operand is 'BT' the other operand is
-- returned unchanged.
addBNat :: BNat n -> BNat m -> BNat (n+m)
addBNat lhs rhs = case lhs of
  BT   -> rhs
  B0 a -> case rhs of
    B0 b -> B0 (addBNat a b)
    B1 b -> B1 (addBNat a b)
    BT   -> lhs
  B1 a -> case rhs of
    B0 b -> B1 (addBNat a b)
    B1 b -> B0 (succBNat (addBNat a b))
    BT   -> lhs
-- | Multiply two binary naturals.
--
-- A 'BT' on either side yields 'BT'. A leading 'B0' on the left operand
-- doubles the product of the remaining digits; a leading 'B1' does the same
-- and then adds the right operand once more.
mulBNat :: BNat n -> BNat m -> BNat (n*m)
mulBNat BT      _  = BT
mulBNat _       BT = BT
mulBNat (B0 a)  b  = B0 (mulBNat a b)
mulBNat (B1 a)  b  = addBNat (B0 (mulBNat a b)) b
-- | Raise a binary natural to the power of another by repeated squaring.
--
-- A 'BT' exponent yields one. For 'B0' the power for the remaining exponent
-- digits is squared; for 'B1' it is squared and multiplied by the base once
-- more.
powBNat :: BNat n -> BNat m -> BNat (n^m)
powBNat base ex = case ex of
  BT      -> B1 BT
  B0 rest -> square (powBNat base rest)
  B1 rest -> mulBNat base (square (powBNat base rest))
 where
  -- Multiply a number by itself; the argument is evaluated at most once.
  square :: BNat k -> BNat (k * k)
  square z = mulBNat z z
-- | Successor of a binary natural: a 'B0' digit flips to 'B1', while a 'B1'
-- digit flips to 'B0' and the increment carries into the remaining digits.
succBNat :: BNat n -> BNat (n+1)
succBNat b = case b of
  BT      -> B1 BT
  B0 rest -> B1 rest
  B1 rest -> B0 (succBNat rest)
-- | Predecessor of a binary natural that is at least 1.
--
-- A 'B1' digit becomes 'B0', unless the remaining digits strip down to 'BT',
-- in which case the result is 'BT' itself. A 'B0' digit becomes 'B1' and the
-- decrement borrows from the remaining digits.
predBNat :: (1 <= n) => BNat n -> BNat (n-1)
predBNat b = case b of
  B1 rest -> case stripZeros rest of
    BT    -> BT
    rest' -> B0 rest'
  B0 rest -> B1 (predBNat rest)
-- | Halve a binary natural whose type says it is even, by dropping its
-- outermost 'B0'. The 'B1' alternative contradicts the argument type and
-- raises an error.
div2BNat :: BNat (2*n) -> BNat n
div2BNat b = case b of
  BT      -> BT
  B0 rest -> rest
  B1 _    -> error "div2BNat: impossible: 2*n ~ 2*n+1"
-- | Drop the outermost 'B1' of a binary natural whose type says it is odd.
-- Every other constructor contradicts the argument type and raises an error.
div2Sub1BNat :: BNat (2*n+1) -> BNat n
div2Sub1BNat b = case b of
  B1 rest -> rest
  _       -> error "div2Sub1BNat: impossible: 2*n+1 ~ 2*n"
-- | Base-2 logarithm of a binary natural whose type says it is a power of
-- two.
--
-- Each outer 'B0' adds one to the result. A 'B1' must be followed only by
-- digits that strip down to 'BT', giving a logarithm of zero. A bare 'BT'
-- argument, or a 'B1' followed by anything else, raises an error.
log2BNat :: BNat (2^n) -> BNat n
log2BNat b = case b of
  BT      -> error "log2BNat: log2(0) not defined"
  B1 rest -> case stripZeros rest of
    BT -> BT
    _  -> error "log2BNat: impossible: 2^n ~ 2x+1"
  B0 rest -> succBNat (log2BNat rest)
-- | Remove 'B0' digits that sit directly in front of the 'BT' terminator.
--
-- These digits do not change the represented number, so the result has the
-- same type index as the argument.
stripZeros :: BNat n -> BNat n
stripZeros b = case b of
  BT      -> BT
  B1 rest -> B1 (stripZeros rest)
  B0 BT   -> BT
  B0 rest -> case stripZeros rest of
    BT    -> BT
    rest' -> B0 rest'
-- | Turn a @k <= n@ constraint into an equality @n ~ (k + m)@.
--
-- The continuation is instantiated with @m@ set to @n - k@.
leToPlus
  :: forall (k :: Nat) (n :: Nat) r
   . ( k <= n
     )
  => (forall m . (n ~ (k + m)) => r)
  -- ^ Context with the @(k + m) ~ n@ equality available
  -> r
leToPlus r = r @(n - k)
{-# INLINE leToPlus #-}
-- | Like turning @k <= n@ into @n ~ (k + m)@, but with 'KnownNat' instances
-- for @k@ and @n@ required and a 'KnownNat' instance for @m@ handed to the
-- continuation.
--
-- The continuation is instantiated with @m@ set to @n - k@.
leToPlusKN
  :: forall (k :: Nat) (n :: Nat) r
   . (k <= n, KnownNat k, KnownNat n)
  => (forall m . (n ~ (k + m), KnownNat m) => r)
  -- ^ Context with the @(k + m) ~ n@ equality and @KnownNat m@ available
  -> r
leToPlusKN cont = cont @(n - k)
{-# INLINE leToPlusKN #-}