-- copied & adapted from cryptic
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ConstraintKinds     #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE Rank2Types          #-}
{-# LANGUAGE GADTs               #-}
module Crypto.Math.Bits
    ( FBits(..)
    , FBitsK(..)
    , SizeValid
    , splitHalf
    , append
    , dropBitsOnRight
    , dropBitsOnLeft
    ) where

import Data.Bits
import Data.Proxy
import GHC.Natural
import GHC.TypeLits

-- | Finite Bits: a 'Natural' tagged at the type level with a bit width @n@.
--
-- Sadly the name @Bits@ is already taken by the 'Bits' class of operations.
--
-- The data constructor stores its argument as-is and performs no truncation
-- to @n@ bits; operations in this module that may exceed the width go
-- through the internal @toFBits@ helper, which masks the value.
data FBits (n :: Nat) = FBits { unFBits :: Natural }
    deriving (Show,Eq,Ord)

-- | Wrapper around an 'FBits' value that is polymorphic in its width: the
-- wrapped value must be usable at every @n@ satisfying the constraints
-- (rank-2 field).
data FBitsK where
    FBitsK :: (forall n . (KnownNat n, SizeValid n) => FBits n) -> FBitsK

-- | Constraint for a usable bit width: @n@ has a 'KnownNat' instance (so its
-- value can be read with 'natVal') and is at least 1.
type SizeValid (n :: Nat) = (KnownNat n, 1 <= n)

-- | Build an 'FBits' from a 'Natural', keeping only the low @n@ bits.
--
-- The raw value is wrapped and then masked with 'allOne' (all @n@ bits set)
-- using the 'Bits' instance of 'FBits', so any bits at position @n@ or above
-- are discarded.
toFBits :: SizeValid n => Natural -> FBits n
toFBits nat = FBits nat .&. allOne

-- | Enumeration over the values representable in @n@ bits.
--
-- 'toEnum' raises an 'error' for arguments outside @[0, 2^n - 1]@.
instance SizeValid n => Enum (FBits n) where
    toEnum i
        -- The negativity test must come first and be joined with (||):
        -- converting a negative Int to Natural would itself fail, and a
        -- non-negative value still has to be checked against the upper bound.
        | i < 0 || fromIntegral i > unFBits maxi = error "FBits n not within bound"
        | otherwise                              = FBits (fromIntegral i)
      where maxi = allOne :: FBits n
    -- Delegates to the Enum instance of the wrapped Natural.
    fromEnum (FBits n) = fromEnum n

-- | Bounds of an @n@-bit value: zero up to all @n@ bits set.
instance SizeValid n => Bounded (FBits n) where
    minBound = FBits 0
    maxBound = allOne

-- | Arithmetic on @n@-bit values, reduced modulo @2^n@.
--
-- Addition and multiplication are computed on the underlying 'Natural' and
-- then truncated with 'toFBits'. Subtraction and 'fromInteger' are computed
-- on 'Integer' and reduced modulo @2^n@, so a negative intermediate result
-- wraps around instead of triggering a 'Natural' underflow; for
-- non-negative results this agrees with truncation by 'toFBits'.
instance SizeValid n => Num (FBits n) where
    -- `mod` with a positive modulus always yields a value in [0, 2^n - 1],
    -- which is safe to convert to Natural.
    fromInteger i = FBits (fromInteger (i `mod` modulus))
      where modulus = 2 ^ natVal (Proxy :: Proxy n) :: Integer
    (+) (FBits a) (FBits b) = toFBits (a + b)
    -- Go through Integer so that a < b wraps (two's-complement style)
    -- rather than throwing on Natural subtraction.
    (-) (FBits a) (FBits b) = fromInteger (toInteger a - toInteger b)
    (*) (FBits a) (FBits b) = toFBits (a * b)
    -- Values are unsigned, so the absolute value is the value itself.
    abs = id
    signum (FBits a)
        | a == 0    = FBits 0
        | otherwise = FBits 1

-- | Bitwise operations on @n@-bit values.
--
-- Operations that can set bits at position @n@ or above ('shiftL', and the
-- rotations built on it) truncate through 'toFBits'. Rotation counts are
-- reduced modulo the width, so counts larger than @n@ (or negative counts)
-- never turn into a negative shift amount.
instance SizeValid n => Bits (FBits n) where
    (.&.) (FBits a) (FBits b) = FBits (a .&. b)
    (.|.) (FBits a) (FBits b) = FBits (a .|. b)
    xor (FBits a) (FBits b) = FBits (a `xor` b)
    shiftR (FBits a) n = FBits (a `shiftR` n)
    shiftL (FBits a) n = toFBits (a `shiftL` n) -- shiftL can overflow here, so explicit safe reconstruction from natural
    rotateL a i = (a `shiftL` k) .|. (a `shiftR` (w - k))
      where w = fromIntegral (natVal (Proxy :: Proxy n)) :: Int
            -- w >= 1 by SizeValid, so `mod` is defined and 0 <= k < w.
            k = i `mod` w
    rotateR a i = (a `shiftR` k) .|. (a `shiftL` (w - k))
      where w = fromIntegral (natVal (Proxy :: Proxy n)) :: Int
            -- w >= 1 by SizeValid, so `mod` is defined and 0 <= k < w.
            k = i `mod` w
    zeroBits = FBits 0
    -- A bit index outside [0, n) yields zero rather than an out-of-width value.
    bit i
        | i < 0 || fromIntegral i >= natVal (Proxy :: Proxy n) = FBits 0
        | otherwise                                            = FBits (2^i)
    testBit (FBits a) i = testBit a i
    bitSize _ = fromIntegral $ natVal (Proxy :: Proxy n)
    bitSizeMaybe _ = Just $ fromIntegral $ natVal (Proxy :: Proxy n)
    isSigned _ = False
    -- Flip exactly the low n bits by xor-ing with the all-ones value.
    complement a = allOne `xor` a
    popCount (FBits a) = popCount a

-- | The @n@-bit value with every bit set, i.e. @2^n - 1@.
allOne :: forall n . SizeValid n => FBits n
allOne = FBits (2 ^ n - 1)
  where n = natVal (Proxy :: Proxy n)

-- | Split a @2n@-bit value into two @n@-bit halves.
--
-- The first component is the input shifted right by @n@ (the upper half);
-- the second is the input truncated to its low @n@ bits via 'toFBits'.
splitHalf :: forall m n . (SizeValid n, (n * 2) ~ m) => FBits m -> (FBits n, FBits n)
splitHalf (FBits a) = (FBits (a `shiftR` n), toFBits a)
  where n = fromIntegral $ natVal (Proxy :: Proxy n)

-- | Append two 'FBits' values: the left operand is shifted left by the
-- width @m@ of the right operand to make room, then the right operand is
-- or-ed into the low bits.
--
-- e.g. append (0x1 :: FBits 1) (0x70 :: FBits 7) = 0xf0 :: FBits 8
append :: forall m n r . (SizeValid m, SizeValid n, SizeValid r, (m + n) ~ r)
       => FBits n -> FBits m -> FBits r
append (FBits a) (FBits b) =
    FBits ((a `shiftL` m) .|. b)
  where m = fromIntegral $ natVal (Proxy :: Proxy m)

--appendK :: FBitsK -> FBitsK -> FBitsK
--appendK (FBitsK a) (FBitsK b) = FBitsK (a `append` b)
    -- FBits ((a `shiftL` m) .|.  b)

-- | Narrow an @a@-bit value to @b@ bits by discarding its @a - b@
-- least-significant bits (a right shift by @diff@).
dropBitsOnRight :: forall a b diff . (KnownNat diff, b <= a, SizeValid a, SizeValid b, (a - b) ~ diff)
              => FBits a
              -> FBits b
dropBitsOnRight (FBits a) = FBits (a `shiftR` fromInteger (natVal (Proxy :: Proxy diff)))

-- | Narrow an @a@-bit value to @b@ bits by keeping only its @b@
-- least-significant bits (truncation through 'toFBits').
dropBitsOnLeft :: forall a b . (KnownNat b, b <= a, SizeValid a, SizeValid b)
             => FBits a
             -> FBits b
dropBitsOnLeft (FBits a) = toFBits a