-- copied & adapted from cryptic
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ConstraintKinds     #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE Rank2Types          #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ViewPatterns        #-}
module Crypto.Math.Bytes
    ( Bytes
    , Endian(..)
    , pack
    , packSome
    , unpack
    , fromBits
    , toBits
    , append
    , take
    , drop
    , splitHalf
    , trace
    ) where

import           Data.Proxy
import           Data.Word
import           Data.List (foldl')
import           GHC.Natural
import           GHC.TypeLits
import           Crypto.Math.NatMath
import           Data.Bits (shiftL)
import           Crypto.Math.Bits (FBits(..))
import           Prelude hiding (take, drop)
import qualified Prelude
import qualified Debug.Trace as Trace

-- | A list of bytes tagged at the type level with its length @n@.
--
-- The constructor itself performs no length check. Only the type (not the
-- constructor) is exported, so outside this module values are built with
-- 'pack' (which checks the length against @n@) or 'packSome' (which derives
-- @n@ from the list), and read back with 'unpack'.
newtype Bytes (n :: Nat) = Bytes { unpack :: [Word8] }
    deriving (Eq, Show)

-- | Byte order used when converting between 'Bytes' and 'FBits'.
data Endian
    = LittleEndian -- ^ least-significant byte first
    | BigEndian    -- ^ most-significant byte first
    deriving (Eq, Show)

-- | Wrap a list of bytes as a @Bytes n@.
--
-- The list must contain exactly @n@ bytes; any other length calls 'error'.
pack :: forall n . KnownNat n => [Word8] -> Bytes n
pack bytes
    | length bytes /= expected = error "packing failed: length not as expected"
    | otherwise                = Bytes bytes
  where
    -- the length demanded by the type-level index
    expected :: Int
    expected = fromIntegral (natVal (Proxy :: Proxy n))

-- | Run a continuation on a list of bytes whose length has been lifted to
-- the type level.
--
-- The type-level length @n@ is taken from the list itself, so the 'Bytes'
-- value is built directly. Going through 'pack' would only traverse the
-- list a second time to re-check a length that holds by construction.
packSome :: (forall n . KnownNat n => Bytes n -> a) -> [Word8] -> a
packSome f l =
    case someNatVal (fromIntegral (length l)) of
        Just (SomeNat (_ :: Proxy n)) -> f (Bytes l :: Bytes n)
        -- someNatVal only yields Nothing for negative integers, and a list
        -- length is never negative.
        Nothing                       -> error "impossible"

-- | Convert between the given byte order and most-significant-first order.
--
-- 'BigEndian' lists are already most-significant first and are returned
-- unchanged. 'LittleEndian' lists are reversed. Reversal is its own inverse,
-- so the same function serves both directions ('toBits' and 'fromBits').
fixupBytes :: Endian -> [Word8] -> [Word8]
fixupBytes endian = case endian of
    LittleEndian -> reverse
    BigEndian    -> id

-- | Debugging helper: when the result is demanded, emit @label: <hex>@
-- through 'Debug.Trace.trace', then return the bytes unchanged.
--
-- Each byte is rendered as exactly two lowercase hexadecimal digits.
trace :: String -> Bytes n -> Bytes n
trace cmd bytes = Trace.trace message bytes
  where
    message = cmd ++ ": " ++ concatMap byteToHex (unpack bytes)

    -- high nibble first, then low nibble
    byteToHex :: Word8 -> String
    byteToHex w = [nibble hi, nibble lo]
      where (hi, lo) = w `divMod` 16

    -- a value in 0..15 as a hex digit character
    nibble :: Word8 -> Char
    nibble d
        | d < 10    = toEnum (fromEnum '0' + fromIntegral d)
        | otherwise = toEnum (fromEnum 'a' + fromIntegral (d - 10))

-- | Transform bytes into bits with a specific endianness.
--
-- The bytes are first put into most-significant-first order ('fixupBytes'),
-- then folded into a single number by shifting the accumulator left 8 bits
-- before adding each byte.
toBits :: Endian -> Bytes n -> FBits (n * 8)
toBits endian (Bytes l) = FBits (foldl' step 0 msbFirst)
  where
    msbFirst = fixupBytes endian l

    step :: Natural -> Word8 -> Natural
    step acc byte = (acc `shiftL` 8) + fromIntegral byte

-- | Transform bits into bytes with a specific endianness.
--
-- The bit width @n@ must be a multiple of 8; otherwise this calls 'error'.
-- Exactly @n \`div\` 8@ bytes are produced. Any bits of the underlying
-- number above position @n@ are ignored.
fromBits :: forall n . KnownNat n => Endian -> FBits n -> Bytes (Div8 n)
fromBits endian (unFBits -> allBits)
    | leftover /= 0 =
        error "Crypto.Math.Bytes.fromBits: bit width is not a multiple of 8"
    | otherwise     = Bytes $ fixupBytes endian (loop [] byteCount allBits)
  where
    (byteCount, leftover) = natVal (Proxy :: Proxy n) `divMod` 8

    -- Peel off the least-significant byte @count@ times, consing each onto
    -- the accumulator, so the result is most-significant byte first.
    loop :: [Word8] -> Integer -> Natural -> [Word8]
    loop acc 0     _   = acc
    loop acc count nat =
        let (nat', b) = divMod8 nat
         in loop (b : acc) (count - 1) nat'

    -- split off the lowest byte of a number
    divMod8 :: Natural -> (Natural, Word8)
    divMod8 i = let (d, m) = i `divMod` 256 in (d, fromIntegral m)


-- | Split a byte string of length @n * 2@ into its first and second halves,
-- each of length @n@.
splitHalf :: forall m n . (KnownNat n, (n * 2) ~ m) => Bytes m -> (Bytes n, Bytes n)
splitHalf (Bytes l) =
    let half :: Int
        half = fromIntegral (natVal (Proxy :: Proxy n))
        -- lazy binding: the split is only performed when a half is demanded
        (front, back) = splitAt half l
     in (Bytes front, Bytes back)

-- | Concatenate two byte strings; the result length is the sum of the
-- input lengths. The first argument's bytes come first.
append :: forall m n r . ((m + n) ~ r)
       => Bytes n -> Bytes m -> Bytes r
append front back = Bytes (unpack front <> unpack back)

-- | Keep the first @n@ bytes of an @m@-byte string (@n <= m@).
take :: forall n m .(KnownNat n, n <= m) => Bytes m -> Bytes n
take (Bytes l) = Bytes (Prelude.take count l)
  where
    -- number of leading bytes to keep
    count :: Int
    count = fromIntegral (natVal (Proxy :: Proxy n))

-- | Drop leading bytes of an @m@-byte string so that exactly the last @n@
-- bytes remain (@n <= m@); @m - n@ bytes are discarded.
drop :: forall n m . (KnownNat m, KnownNat n, n <= m) => Bytes m -> Bytes n
drop (Bytes l) = Bytes (Prelude.drop (fromIntegral (total - kept)) l)
  where
    total, kept :: Integer
    total = natVal (Proxy :: Proxy m)
    kept  = natVal (Proxy :: Proxy n)