{-# LANGUAGE ExistentialQuantification #-}
{-# OPTIONS_HADDOCK hide #-}

module Network.TLS.Cipher (
    CipherKeyExchangeType (..),
    Bulk (..),
    BulkFunctions (..),
    BulkDirection (..),
    BulkState (..),
    BulkStream (..),
    BulkBlock,
    BulkAEAD,
    bulkInit,
    Hash (..),
    Cipher (..),
    CipherID,
    cipherKeyBlockSize,
    BulkKey,
    BulkIV,
    BulkNonce,
    BulkAdditionalData,
    cipherAllowedForVersion,
    hasMAC,
    hasRecordIV,
    elemCipher,
    intersectCiphers,
    findCipher,
) where

import Network.TLS.Crypto (Hash (..), hashDigestSize)
import Network.TLS.Imports
import Network.TLS.Types

-- | Runtime state of a bulk cipher, as produced by 'bulkInit' for a single
-- 'BulkDirection'.
--
-- There is one constructor per kind of 'BulkFunctions' (stream, block,
-- AEAD), plus 'BulkStateUninitialized', which carries no cipher state at
-- all (presumably the value used before any keys are installed — confirm).
data BulkState
    = -- | State of a stream cipher.
      BulkStateStream BulkStream
    | -- | State of a block cipher.
      BulkStateBlock BulkBlock
    | -- | State of an AEAD cipher.
      BulkStateAEAD BulkAEAD
    | -- | No bulk cipher state.
      BulkStateUninitialized

-- | Shows only the constructor name; the wrapped cipher state is never
-- rendered (the payload types are presumably opaque — confirm).
instance Show BulkState where
    show st = case st of
        BulkStateStream _ -> "BulkStateStream"
        BulkStateBlock _ -> "BulkStateBlock"
        BulkStateAEAD _ -> "BulkStateAEAD"
        BulkStateUninitialized -> "BulkStateUninitialized"

-- | Build the 'BulkState' for one 'BulkDirection' from a bulk algorithm
-- and its key.
--
-- The initialiser stored in @'bulkF' bulk@ is applied to @direction@ and
-- @key@. The result is wrapped in the 'BulkState' constructor that matches
-- the kind of 'BulkFunctions' (block, stream or AEAD), so this function
-- never returns 'BulkStateUninitialized'.
bulkInit :: Bulk -> BulkDirection -> BulkKey -> BulkState
bulkInit bulk direction key = start (bulkF bulk)
  where
    start (BulkBlockF mkBlock) = BulkStateBlock $ mkBlock direction key
    start (BulkStreamF mkStream) = BulkStateStream $ mkStream direction key
    start (BulkAeadF mkAead) = BulkStateAEAD $ mkAead direction key

-- | Whether a bulk algorithm is paired with a separate MAC.
--
-- Block and stream bulk functions report 'True'. AEAD bulk functions report
-- 'False' (presumably because AEAD authenticates records itself — confirm).
--
-- 'hasRecordIV' is defined as exactly the same predicate as 'hasMAC'.
-- NOTE(review): stream ciphers usually have no per-record IV; confirm that
-- callers of 'hasRecordIV' rely on it only where that distinction is moot.
hasMAC, hasRecordIV :: BulkFunctions -> Bool
hasMAC funcs = case funcs of
    BulkBlockF _ -> True
    BulkStreamF _ -> True
    BulkAeadF _ -> False
hasRecordIV = hasMAC

-- | Number of bytes of key material required by a 'Cipher'.
--
-- One share is the digest size of 'cipherHash' plus the IV size and the
-- key size of 'cipherBulk'. The result is twice that share (presumably one
-- MAC key / write key / IV set per direction, as in the TLS 1.2 key_block —
-- confirm).
--
-- NOTE(review): the digest size is counted unconditionally, even for AEAD
-- bulks, for which 'hasMAC' is 'False'. Confirm that the code slicing the
-- key block expects this layout.
cipherKeyBlockSize :: Cipher -> Int
cipherKeyBlockSize cipher = 2 * perSide
  where
    bulk = cipherBulk cipher
    perSide =
        hashDigestSize (cipherHash cipher)
            + bulkIVSize bulk
            + bulkKeySize bulk

-- | Check whether a specific 'Cipher' may be used with the given 'Version'.
--
-- * A cipher without a minimum version ('cipherMinVer' is 'Nothing') is
--   allowed only for versions below 'TLS13'.
-- * A cipher with minimum version @v@ is allowed only for versions that are
--   at least @v@. In addition, for 'TLS13' and above, only ciphers whose
--   own minimum is at least 'TLS13' are accepted.
cipherAllowedForVersion :: Version -> Cipher -> Bool
cipherAllowedForVersion ver cipher =
    maybe beforeTLS13 allowedFrom (cipherMinVer cipher)
  where
    beforeTLS13 = ver < TLS13
    allowedFrom minVer = minVer <= ver && (beforeTLS13 || minVer >= TLS13)

-- | Whether the 'Cipher' has exactly the given 'CipherID'.
eqCipher :: CipherID -> Cipher -> Bool
eqCipher cid = (== cid) . cipherID

-- | Whether the identifier of the 'Cipher', wrapped as a 'CipherId', occurs
-- in the given list (linear scan).
elemCipher :: [CipherId] -> Cipher -> Bool
elemCipher cids = (`elem` cids) . CipherId . cipherID

-- | Keep those of our ciphers (second argument) whose identifiers appear in
-- the peer's list (first argument).
--
-- The result preserves the order of our list. Each membership test is a
-- linear scan of the peer's list.
intersectCiphers :: [CipherId] -> [Cipher] -> [Cipher]
intersectCiphers peerCiphers myCiphers =
    [cipher | cipher <- myCiphers, elemCipher peerCiphers cipher]

-- | Return the first 'Cipher' in the list whose identifier equals the given
-- 'CipherID', or 'Nothing' if none matches.
findCipher :: CipherID -> [Cipher] -> Maybe Cipher
findCipher cid = find wanted
  where
    wanted = eqCipher cid