{-# LANGUAGE RecordWildCards #-}

module Crypto.HPKE.Context (
    -- * Sender
    ContextS,
    newContextS,
    seal,
    exportS,

    -- * Receiver
    ContextR,
    newContextR,
    open,
    exportR,
) where

import qualified Control.Exception as E
import Data.ByteArray (xor)
import qualified Data.ByteString as BS
import Data.IORef (
    IORef,
    atomicModifyIORef',
    newIORef,
    readIORef,
    writeIORef,
 )

import Crypto.HPKE.Types

----------------------------------------------------------------

-- | Context for senders.
--
-- The mutable sequence number is combined with 'nonceBaseS' to derive the
-- nonce used by each call to 'seal'.
data ContextS = ContextS
    { seqRefS :: IORef Integer
    -- ^ Sequence number of the next message to seal
    -- ('newContextS' initialises it to 0).
    , sealS :: Seal
    -- ^ Seal function, already applied to the context key.
    , nonceBaseS :: Nonce
    -- ^ Base nonce; XORed with the encoded sequence number.
    , expandS :: Info -> Int -> Key
    -- ^ Secret-export function used by 'exportS'.
    }

-- | Context for receivers.
--
-- Mirrors 'ContextS': the mutable sequence number is combined with
-- 'nonceBaseR' to derive the nonce used by each call to 'open'.
data ContextR = ContextR
    { seqRefR :: IORef Integer
    -- ^ Sequence number of the next message to open
    -- ('newContextR' initialises it to 0).
    , openR :: Open
    -- ^ Open function, already applied to the context key.
    , nonceBaseR :: Nonce
    -- ^ Base nonce; XORed with the encoded sequence number.
    , expandR :: Info -> Int -> Key
    -- ^ Secret-export function used by 'exportR'.
    }

----------------------------------------------------------------

-- | Encryption.
--
-- Encrypts @pt@ under the nonce derived from the context's current
-- sequence number: the number is encoded with 'i2ospOf_' to the length of
-- the base nonce and XORed with it. The sequence number is consumed whether
-- or not sealing succeeds.
--
-- NOTE(review): 'i2ospOf_' fails with a pure 'error' (not an 'HPKEError')
-- once the sequence number no longer fits in the nonce length; RFC 9180
-- reports MessageLimitReachedError here instead — confirm desired behaviour.
--
--   This throws 'HPKEError'.
seal :: ContextS -> AAD -> PlainText -> IO CipherText
seal ContextS{..} aad pt = do
    -- Reserve the current sequence number and advance the counter in one
    -- atomic step. A separate read/write pair could hand the same nonce to
    -- two concurrent callers, and nonce reuse breaks AEAD security.
    seqI <- atomicModifyIORef' seqRefS $ \s -> (s + 1, s)
    let len = BS.length nonceBaseS
        seqBS = i2ospOf_ len seqI :: ByteString
        nonce = seqBS `xor` nonceBaseS
    either E.throwIO return $ sealS nonce aad pt

-- | Decryption.
--
-- Decrypts @ct@ under the nonce derived from the context's current sequence
-- number, computed the same way as in 'seal'. The sequence number advances
-- only when decryption succeeds, so a rejected ciphertext does not consume
-- a number.
--
-- NOTE(review): the read and the later write of the sequence number are not
-- atomic, so a single receiver context should not be used concurrently.
--
--   This throws 'HPKEError'.
open
    :: ContextR -> AAD -> CipherText -> IO PlainText
open ContextR{..} aad ct = do
    seqI <- readIORef seqRefR
    let len = BS.length nonceBaseR
        seqBS = i2ospOf_ len seqI :: ByteString
        nonce = seqBS `xor` nonceBaseR
    case openR nonce aad ct of
        Left err -> E.throwIO err
        Right pt -> do
            -- Store the successor strictly so no thunk is left in the ref.
            writeIORef seqRefR $! seqI + 1
            return pt

----------------------------------------------------------------

-- | Exporting secret.
--
-- Derives @len@ bytes from the exporter context using the export function
-- supplied to 'newContextS' (presumably RFC 9180's @Export@ — confirm at
-- the call site). The sequence number is neither read nor changed.
exportS :: ContextS -> Info -> Int -> Key
exportS ctx exporterContext len = expandS ctx exporterContext len

-- | Exporting secret.
--
-- Derives @len@ bytes from the exporter context using the export function
-- supplied to 'newContextR' (presumably RFC 9180's @Export@ — confirm at
-- the call site). The sequence number is neither read nor changed.
exportR :: ContextR -> Info -> Int -> Key
exportR ctx exporterContext len = expandR ctx exporterContext len

----------------------------------------------------------------

-- | Create a sender context whose sequence number starts at 0.
newContextS
    :: Key
    -- ^ Key handed to the seal function
    -> Nonce
    -- ^ Base nonce
    -> (Key -> Seal)
    -- ^ Seal function; applied to the key once, here
    -> (Info -> Int -> Key)
    -- ^ Secret-export function used by 'exportS'
    -> IO ContextS
newContextS key nonceBase seal' expand = do
    seqRef <- newIORef 0
    return $
        ContextS
            { seqRefS = seqRef
            , sealS = seal' key
            , nonceBaseS = nonceBase
            , expandS = expand
            }

----------------------------------------------------------------

-- | Create a receiver context whose sequence number starts at 0.
newContextR
    :: Key
    -- ^ Key handed to the open function
    -> Nonce
    -- ^ Base nonce
    -> (Key -> Open)
    -- ^ Open function; applied to the key once, here
    -> (Info -> Int -> Key)
    -- ^ Secret-export function used by 'exportR'
    -> IO ContextR
newContextR key nonceBase open' expand = do
    seqRef <- newIORef 0
    return $
        ContextR
            { seqRefR = seqRef
            , openR = open' key
            , nonceBaseR = nonceBase
            , expandR = expand
            }