{-# LANGUAGE CApiFFI #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- |
-- Module      : Crypto.Cipher.ChaCha
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : stable
-- Portability : good
module Crypto.Cipher.ChaCha (
    initialize,
    initializeX,
    combine,
    generate,
    State,

    -- * Simple interface for DRG purpose
    initializeSimple,
    generateSimple,
    StateSimple,

    -- * Seeking and cursor for DRG purposes
    generateSimpleBlock,
    ChaChaState (..),
) where

import Crypto.Internal.ByteArray (
    ByteArray,
    ByteArrayAccess,
    ScrubbedBytes,
 )
import qualified Crypto.Internal.ByteArray as B
import Crypto.Internal.Compat
import Crypto.Internal.Imports
import Foreign.C.Types
import Foreign.Ptr

-- | Opaque ChaCha (or XChaCha) cipher context, as produced by 'initialize'
-- and 'initializeX' and threaded through 'combine' and 'generate'.
-- The underlying memory is a 'ScrubbedBytes' buffer.
newtype State = State ScrubbedBytes deriving (NFData)

-- | ChaCha context for DRG purpose (see "Crypto.Random.ChaChaDRG").
-- Holds nothing but the ChaCha state itself, as built by 'initializeSimple'
-- and advanced by 'generateSimple' and 'generateSimpleBlock'.
newtype StateSimple = StateSimple ScrubbedBytes deriving (NFData)

-- | Access to the block counter of a ChaCha state, for seeking within the
-- keystream. The instances in this module implement the setters by
-- returning a modified copy; the argument is never changed in place.
class ChaChaState a where
    -- | Read the block counter as a 64-bit value.
    getCounter64 :: a -> Word64
    -- | Return a copy whose block counter is the given 64-bit value.
    setCounter64 :: Word64 -> a -> a
    -- | Read the block counter as a 32-bit value.
    getCounter32 :: a -> Word32
    -- | Return a copy whose block counter is the given 32-bit value.
    setCounter32 :: Word32 -> a -> a

-- A full 'State' context embeds the ChaCha state; 'ccrypton_chacha_get_state'
-- maps a pointer to the context onto a pointer to that embedded state, which
-- is what the shared counter helpers operate on.
instance ChaChaState State where
    getCounter64 (State ctx) = getCounter64' ctx ccrypton_chacha_get_state
    getCounter32 (State ctx) = getCounter32' ctx ccrypton_chacha_get_state
    setCounter64 counter (State ctx) =
        State (setCounter64' counter ctx ccrypton_chacha_get_state)
    setCounter32 counter (State ctx) =
        State (setCounter32' counter ctx ccrypton_chacha_get_state)

-- A 'StateSimple' buffer already is the bare ChaCha state, so no pointer
-- adjustment is required: the conversion handed to the helpers is 'id'.
instance ChaChaState StateSimple where
    setCounter64 counter (StateSimple raw) =
        StateSimple (setCounter64' counter raw id)
    setCounter32 counter (StateSimple raw) =
        StateSimple (setCounter32' counter raw id)
    getCounter64 (StateSimple raw) = getCounter64' raw id
    getCounter32 (StateSimple raw) = getCounter32' raw id

-- | Read the 64-bit block counter out of a context buffer.
--
-- @conv@ turns a pointer to the start of @currSt@ into a pointer to the
-- ChaCha state inside it ('id' when the buffer is the bare state).
-- The result is exposed as a pure value, which relies on the C getter only
-- reading the buffer (presumably true for a getter — confirm in cbits).
getCounter64' :: ScrubbedBytes -> (Ptr a -> Ptr StateSimple) -> Word64
getCounter64' currSt conv =
    unsafeDoIO (B.withByteArray currSt (ccrypton_chacha_counter64 . conv))

-- | Read the 32-bit block counter out of a context buffer.
--
-- @conv@ turns a pointer to the start of @currSt@ into a pointer to the
-- ChaCha state inside it ('id' when the buffer is the bare state).
-- As with 'getCounter64'', purity relies on the C getter not mutating the
-- buffer (presumably true — confirm in cbits).
getCounter32' :: ScrubbedBytes -> (Ptr a -> Ptr StateSimple) -> Word32
getCounter32' currSt conv = unsafeDoIO readCounter
  where
    readCounter = B.withByteArray currSt $ \basePtr ->
        ccrypton_chacha_counter32 (conv basePtr)

-- | Return a copy of the context buffer @prevSt@ whose 64-bit block counter
-- is @newCounter@; @prevSt@ itself is left untouched.
--
-- @conv@ locates the ChaCha state inside the buffer ('id' for a bare state).
-- 'B.copy' hands its callback a pointer to the freshly copied buffer, so the
-- counter is written into the copy only.
setCounter64'
    :: Word64 -> ScrubbedBytes -> (Ptr a -> Ptr StateSimple) -> ScrubbedBytes
setCounter64' newCounter prevSt conv =
    unsafeDoIO $
        B.copy prevSt $ \copyPtr ->
            ccrypton_chacha_set_counter64 (conv copyPtr) newCounter

-- | Return a copy of the context buffer @prevSt@ whose 32-bit block counter
-- is @newCounter@; @prevSt@ itself is left untouched.
--
-- @conv@ locates the ChaCha state inside the buffer ('id' for a bare state).
-- 'B.copy' hands its callback a pointer to the freshly copied buffer, so the
-- counter is written into the copy only.
setCounter32'
    :: Word32 -> ScrubbedBytes -> (Ptr a -> Ptr StateSimple) -> ScrubbedBytes
setCounter32' newCounter prevSt conv =
    unsafeDoIO (B.copy prevSt writeCounter)
  where
    writeCounter copyPtr = ccrypton_chacha_set_counter32 (conv copyPtr) newCounter

-- | Create a ChaCha context from a round count, a key and a nonce.
--
-- Arguments are validated in this order, each failure raising 'error':
-- key length (16 or 32 bytes), nonce length (8 or 12 bytes), then the
-- number of rounds (8, 12 or 20).
initialize
    :: (ByteArrayAccess key, ByteArrayAccess nonce)
    => Int
    -- ^ rounds: 8, 12 or 20
    -> key
    -- ^ key, 16 or 32 bytes (128 or 256 bits)
    -> nonce
    -- ^ nonce, 8 or 12 bytes (64 or 96 bits)
    -> State
    -- ^ freshly initialised ChaCha state
initialize nbRounds key nonce
    | kLen /= 16 && kLen /= 32 =
        error "ChaCha: key length should be 128 or 256 bits"
    | nonceLen /= 8 && nonceLen /= 12 =
        error "ChaCha: nonce length should be 64 or 96 bits"
    | nbRounds /= 8 && nbRounds /= 12 && nbRounds /= 20 =
        error "ChaCha: rounds should be 8, 12 or 20"
    | otherwise =
        State $ unsafeDoIO $
            B.alloc contextSize $ \ctxPtr ->
                B.withByteArray nonce $ \noncePtr ->
                    B.withByteArray key $ \keyPtr ->
                        ccrypton_chacha_init ctxPtr nbRounds kLen keyPtr nonceLen noncePtr
  where
    kLen = B.length key
    nonceLen = B.length nonce

    -- Size in bytes of the C-side context buffer filled by ccrypton_chacha_init.
    contextSize :: Int
    contextSize = 132

-- | Create an XChaCha context (extended 24-byte nonce) from a round count,
-- a 32-byte key and a 24-byte nonce.
--
-- Arguments are validated in this order, each failure raising 'error':
-- key length, nonce length, then the number of rounds. Once created, the
-- 'State' is used exactly like one returned by 'initialize'.
initializeX
    :: (ByteArrayAccess key, ByteArrayAccess nonce)
    => Int
    -- ^ rounds: 8, 12 or 20
    -> key
    -- ^ key, exactly 32 bytes (256 bits)
    -> nonce
    -- ^ nonce, exactly 24 bytes (192 bits)
    -> State
    -- ^ freshly initialised XChaCha state
initializeX nbRounds key nonce
    | B.length key /= 32 = error "XChaCha: key length should be 256 bits"
    | B.length nonce /= 24 = error "XChaCha: nonce length should be 192 bits"
    | nbRounds `notElem` validRounds = error "XChaCha: rounds should be 8, 12 or 20"
    | otherwise = State (unsafeDoIO buildContext)
  where
    validRounds = [8, 12, 20]

    -- Allocate the 132-byte C context and let the C side derive the XChaCha
    -- state from key and nonce.
    buildContext =
        B.alloc 132 $ \ctxPtr ->
            B.withByteArray nonce $ \noncePtr ->
                B.withByteArray key $ \keyPtr ->
                    ccrypton_xchacha_init ctxPtr nbRounds keyPtr noncePtr

-- | Initialize a simple ChaCha state from raw seed material.
--
-- The seed must be at least 40 bytes long; shorter seeds raise 'error'.
-- The first 32 bytes and the following 8 bytes are handed to the C
-- initializer (as key and nonce respectively, judging by the argument order
-- shared with 'initialize'); any bytes past the 40th are ignored.
initializeSimple
    :: ByteArrayAccess seed
    => seed
    -- ^ seed material, at least 40 bytes
    -> StateSimple
initializeSimple seed
    | B.length seed < 40 = error "ChaCha Random: seed length should be 40 bytes"
    | otherwise =
        StateSimple $ unsafeDoIO $
            B.alloc 64 $ \statePtr ->
                B.withByteArray seed $ \seedPtr ->
                    let keyPtr = seedPtr
                        noncePtr = seedPtr `plusPtr` 32
                     in ccrypton_chacha_init_core statePtr 32 keyPtr 8 noncePtr

-- | Combine the ChaCha keystream with an arbitrary message using XOR, and
-- return the combined output (same length as the message) together with
-- the advanced state.
--
-- The input state is copied before use and is never modified. An empty
-- message yields an empty output and the unchanged state.
combine
    :: ByteArray ba
    => State
    -- ^ the current ChaCha state
    -> ba
    -- ^ the source to xor with the generator
    -> (ba, State)
combine prevSt@(State prevStMem) src
    | B.null src = (B.empty, prevSt)
    | otherwise = unsafeDoIO $ do
        (out, st) <- B.copyRet prevStMem $ \ctx ->
            B.alloc srcLen $ \dstPtr ->
                B.withByteArray src $ \srcPtr ->
                    combineChunks dstPtr ctx srcPtr srcLen
        return (out, State st)
  where
    srcLen = B.length src

    -- The C primitive takes its length as a 32-bit 'CUInt'. A single call
    -- with @fromIntegral srcLen@ would silently truncate messages of 4 GiB
    -- or more and leave the tail of the output buffer uninitialised. Instead,
    -- feed it bounded chunks on the same context. This is exactly what
    -- consecutive 'combine' calls on the returned state do. Messages shorter
    -- than one chunk still produce a single, unchanged C call.
    combineChunks :: Ptr Word8 -> Ptr State -> Ptr Word8 -> Int -> IO ()
    combineChunks dstPtr ctx srcPtr remaining
        | remaining <= 0 = return ()
        | otherwise = do
            let n = min remaining maxChunk
            ccrypton_chacha_combine dstPtr ctx srcPtr (fromIntegral n)
            combineChunks (dstPtr `plusPtr` n) ctx (srcPtr `plusPtr` n) (remaining - n)

    -- 2^30: fits in a 32-bit CUInt and is a whole number of 64-byte blocks.
    maxChunk :: Int
    maxChunk = 1073741824

-- | Generate a number of bytes directly from the ChaCha keystream, and
-- return them together with the advanced state.
--
-- The input state is copied before use and is never modified. A
-- non-positive length yields an empty output and the unchanged state.
generate
    :: ByteArray ba
    => State
    -- ^ the current ChaCha state
    -> Int
    -- ^ the length of data to generate
    -> (ba, State)
generate prevSt@(State prevStMem) len
    | len <= 0 = (B.empty, prevSt)
    | otherwise = unsafeDoIO $ do
        (out, st) <- B.copyRet prevStMem $ \ctx ->
            B.alloc len $ \dstPtr -> generateChunks dstPtr ctx len
        return (out, State st)
  where
    -- The C primitive takes its length as a 32-bit 'CUInt'. A single call
    -- with @fromIntegral len@ would silently truncate requests of 4 GiB or
    -- more and return a partly uninitialised buffer. Instead, feed it bounded
    -- chunks on the same context. This is exactly what consecutive 'generate'
    -- calls on the returned state do. Requests shorter than one chunk still
    -- produce a single, unchanged C call.
    generateChunks :: Ptr Word8 -> Ptr State -> Int -> IO ()
    generateChunks dstPtr ctx remaining
        | remaining <= 0 = return ()
        | otherwise = do
            let n = min remaining maxChunk
            ccrypton_chacha_generate dstPtr ctx (fromIntegral n)
            generateChunks (dstPtr `plusPtr` n) ctx (remaining - n)

    -- 2^30: fits in a 32-bit CUInt and is a whole number of 64-byte blocks.
    maxChunk :: Int
    maxChunk = 1073741824

-- | Similar to 'generate', but operates on a 'StateSimple' with fixed
-- parameters. The C generator is called with a hard-coded first argument of
-- 8, presumably the round count; compare 'generateSimpleBlock', which takes
-- the rounds explicitly.
--
-- A non-positive byte count yields an empty output and the unchanged state,
-- matching 'generate'. Requests too large for the C primitive's 32-bit
-- length are rejected with 'error'. Otherwise the truncated length would
-- leave most of the returned buffer uninitialised.
generateSimple
    :: ByteArray ba
    => StateSimple
    -- ^ the current simple ChaCha state
    -> Int
    -- ^ number of bytes to generate
    -> (ba, StateSimple)
generateSimple st@(StateSimple prevSt) nbBytes
    | nbBytes <= 0 = (B.empty, st)
    | toInteger nbBytes > toInteger (maxBound :: CUInt) =
        error "ChaCha: generateSimple: requested length exceeds 32-bit limit"
    | otherwise = unsafeDoIO $ do
        -- Work on a private copy so the caller's state stays immutable.
        newSt <- B.copy prevSt (\_ -> return ())
        output <- B.alloc nbBytes $ \dstPtr ->
            B.withByteArray newSt $ \stPtr ->
                ccrypton_chacha_random 8 dstPtr stPtr (fromIntegral nbBytes)
        return (output, StateSimple newSt)

-- | Like 'generateSimple', but with an explicit number of rounds (8, 12 or
-- 20; anything else raises 'error'). It always produces exactly 64 bytes,
-- i.e. a single ChaCha block, and returns them together with the advanced
-- state. The input state is copied first and never modified.
generateSimpleBlock
    :: ByteArray ba
    => Word8
    -- ^ rounds: 8, 12 or 20
    -> StateSimple
    -- ^ the current simple ChaCha state
    -> (ba, StateSimple)
generateSimpleBlock nbRounds (StateSimple prevSt)
    | nbRounds `elem` [8, 12, 20] = unsafeDoIO $ do
        (block, newSt) <- B.copyRet prevSt $ \stPtr ->
            B.alloc 64 $ \dstPtr ->
                ccrypton_chacha_generate_simple_block dstPtr stPtr nbRounds
        return (block, StateSimple newSt)
    | otherwise = error "ChaCha: rounds should be 8, 12 or 20"

-- Initialise a bare ChaCha state (used by 'initializeSimple', which passes:
-- state, 32, seed, 8, seed+32 — presumably key length, key, nonce length,
-- nonce; confirm against cbits).
-- NOTE(review): integer arguments here and in the init imports below are
-- marshalled as Haskell 'Int'; confirm the C prototypes take a matching
-- width (otherwise 'CInt'/'Word32'/'Word8' would be the exact type).
foreign import ccall unsafe "crypton_chacha_init_core"
    ccrypton_chacha_init_core
        :: Ptr StateSimple -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO ()

-- Initialise a full ChaCha context; called by 'initialize' with
-- (context, rounds, key length, key, nonce length, nonce).
foreign import ccall unsafe "crypton_chacha_init"
    ccrypton_chacha_init
        :: Ptr State -> Int -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO ()

-- Initialise an XChaCha context; called by 'initializeX' with
-- (context, rounds, key, nonce). Lengths are fixed (32/24 bytes) and are
-- checked on the Haskell side, so they are not passed.
foreign import ccall unsafe "crypton_xchacha_init"
    ccrypton_xchacha_init :: Ptr State -> Int -> Ptr Word8 -> Ptr Word8 -> IO ()

-- The next three imports are plain (safe) calls rather than @unsafe@.
-- Presumably this is because their running time grows with the 'CUInt'
-- length argument, and a safe call lets other Haskell threads run
-- meanwhile. The length is 32-bit, so callers must not pass more than
-- 'maxBound' :: 'CUInt' bytes in one call.

-- XOR keystream into dst from src for the given number of bytes, advancing
-- the context (argument order: dst, context, src, length).
foreign import ccall "crypton_chacha_combine"
    ccrypton_chacha_combine :: Ptr Word8 -> Ptr State -> Ptr Word8 -> CUInt -> IO ()

-- Write raw keystream bytes into dst, advancing the context
-- (argument order: dst, context, length).
foreign import ccall "crypton_chacha_generate"
    ccrypton_chacha_generate :: Ptr Word8 -> Ptr State -> CUInt -> IO ()

-- DRG output for 'generateSimple', which passes 8 as the first argument
-- (presumably the round count — confirm against cbits).
foreign import ccall "crypton_chacha_random"
    ccrypton_chacha_random :: Int -> Ptr Word8 -> Ptr StateSimple -> CUInt -> IO ()

-- Block-counter accessors used by the 'ChaChaState' helpers. The getters
-- back pure functions via 'unsafeDoIO', so they are expected not to
-- mutate the state (assumption — confirm in cbits).
foreign import ccall unsafe "crypton_chacha_counter64"
    ccrypton_chacha_counter64 :: Ptr StateSimple -> IO Word64

foreign import ccall unsafe "crypton_chacha_set_counter64"
    ccrypton_chacha_set_counter64 :: Ptr StateSimple -> Word64 -> IO ()

foreign import ccall unsafe "crypton_chacha_counter32"
    ccrypton_chacha_counter32 :: Ptr StateSimple -> IO Word32

foreign import ccall unsafe "crypton_chacha_set_counter32"
    ccrypton_chacha_set_counter32 :: Ptr StateSimple -> Word32 -> IO ()

-- Produce one 64-byte block for 'generateSimpleBlock'
-- (argument order: dst, state, rounds).
foreign import ccall unsafe "crypton_chacha_generate_simple_block"
    ccrypton_chacha_generate_simple_block
        :: Ptr Word8 -> Ptr StateSimple -> Word8 -> IO ()

-- Imported via CApiFFI from crypton_chacha.h, which lets the header supply
-- it as a macro or inline function. Maps a pointer to a full context onto a
-- pointer to its embedded ChaCha state. It is declared pure (no 'IO'), so it
-- must neither read nor write memory; presumably it is pointer arithmetic
-- only — confirm in the header.
foreign import capi unsafe "crypton_chacha.h crypton_chacha_get_state"
    ccrypton_chacha_get_state :: Ptr State -> Ptr StateSimple