{-# LANGUAGE BangPatterns #-}
module Network.TLS.Record.Encrypt (
    encryptRecord,
) where
import Control.Monad.State.Strict
import Crypto.Cipher.Types (AuthTag (..))
import qualified Data.ByteArray as B (convert, xor)
import qualified Data.ByteString as B
import Network.TLS.Cipher
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Record.State
import Network.TLS.Record.Types
import Network.TLS.Wire
-- | Turn a plaintext record into a ciphertext record using the current
-- record state.
--
-- * If no cipher is installed ('stCipher' is 'Nothing'), the fragment bytes
--   are passed through unchanged.
-- * If the record options have 'recordTLS13' set, the record is handled by
--   the local @encryptContent13@.
-- * Otherwise the fragment is transformed by 'encryptContent' with the
--   TLS 1.3 flag off.
encryptRecord :: Record Plaintext -> RecordM (Record Ciphertext)
encryptRecord record@(Record ct ver fragment) = do
    st <- get
    case stCipher st of
        Nothing -> noEncryption
        _ -> do
            recOpts <- getRecordOptions
            if recordTLS13 recOpts
                then encryptContent13
                else onRecordFragment record $ fragmentCipher (encryptContent False record)
  where
    -- Re-tag the fragment as ciphertext without touching its bytes.
    noEncryption :: RecordM (Record Ciphertext)
    noEncryption = onRecordFragment record $ fragmentCipher return

    -- TLS 1.3 path: ChangeCipherSpec records are sent as-is; every other
    -- record has its content type appended to the payload (see
    -- 'innerPlaintext') and is then encrypted under the
    -- 'ProtocolType_AppData' outer type.
    encryptContent13 :: RecordM (Record Ciphertext)
    encryptContent13
        | ct == ProtocolType_ChangeCipherSpec = noEncryption
        | otherwise = do
            let bytes = fragmentGetBytes fragment
                fragment' = fragmentPlaintext $ innerPlaintext ct bytes
                record' = Record ProtocolType_AppData ver fragment'
            onRecordFragment record' $ fragmentCipher (encryptContent True record')
-- | Serialise the payload followed by one byte carrying the content type.
--
-- Nothing is written after the type byte, i.e. no padding is added here.
innerPlaintext :: ProtocolType -> ByteString -> ByteString
innerPlaintext (ProtocolType c) bytes = runPut $ do
    putBytes bytes
    putWord8 c
-- | Encrypt the payload of a record, dispatching on the kind of bulk key
-- held in the current 'CryptState'.
--
-- * Block and stream keys: a digest computed by 'makeDigest' over the record
--   header and the content is appended to the content before encryption.
-- * AEAD keys: delegated to 'encryptAead'; the @tls13@ flag is only used on
--   this path.
-- * Uninitialized key: the content is returned unchanged.
encryptContent :: Bool -> Record Plaintext -> ByteString -> RecordM ByteString
encryptContent tls13 record content = do
    cst <- getCryptState
    bulk <- getBulk
    case cstKey cst of
        BulkStateBlock encryptF -> do
            content' <- contentWithDigest
            encryptBlock encryptF content' bulk
        BulkStateStream encryptF -> do
            content' <- contentWithDigest
            encryptStream encryptF content'
        BulkStateAEAD encryptF ->
            encryptAead tls13 bulk encryptF content record
        BulkStateUninitialized ->
            return content
  where
    -- Content followed by its digest; shared by the block and stream paths.
    contentWithDigest :: RecordM ByteString
    contentWithDigest = do
        digest <- makeDigest (recordToHeader record) content
        return $ B.concat [content, digest]
-- | Pad the content to a multiple of the bulk block size, encrypt it with the
-- given block function, and return the IV from the crypt state followed by
-- the ciphertext.
--
-- When the block size is positive, between 1 and @blockSize@ padding bytes
-- are appended, each holding the value @paddingLength - 1@. When it is not
-- positive, no padding is added.
encryptBlock :: BulkBlock -> ByteString -> Bulk -> RecordM ByteString
encryptBlock encryptF content bulk = do
    cst <- getCryptState
    let blockSize = bulkBlockSize bulk
        msgLen = B.length content
        padding
            | blockSize > 0 =
                -- msgLen `mod` blockSize lies in [0, blockSize - 1], so
                -- padLen lies in [1, blockSize] and is never zero.
                let padLen = blockSize - (msgLen `mod` blockSize)
                 in B.replicate padLen (fromIntegral (padLen - 1))
            | otherwise = B.empty
        iv = cstIV cst
        -- NOTE(review): the second component returned by the cipher is
        -- discarded and the crypt state is not updated here; confirm the IV
        -- is refreshed elsewhere.
        (e, _iv') = encryptF iv $ B.concat [content, padding]
    return $ B.concat [iv, e]
-- | Encrypt the content with a stream cipher and store the cipher state it
-- returns back into the record state, so the next call continues from it.
encryptStream :: BulkStream -> ByteString -> RecordM ByteString
encryptStream (BulkStream encryptF) content = do
    cst <- getCryptState
    let (!e, !newBulkStream) = encryptF content
    -- Only the key of the crypt state read above is replaced.
    modify' $ \tstate ->
        tstate{stCryptState = cst{cstKey = BulkStateStream newBulkStream}}
    return e
-- | Encrypt the content with an AEAD cipher and advance the record state.
--
-- The additional data is the encoded record header, preceded by the encoded
-- sequence number unless @tls13@ is set. With @tls13@ set, the header length
-- also includes the authentication tag length.
--
-- The nonce and output layout depend on the explicit IV length of the bulk:
--
-- * zero: nonce is the IV XORed with the left-zero-padded sequence number;
--   output is ciphertext followed by the tag.
-- * non-zero: nonce is the IV followed by the encoded sequence number;
--   output is encoded sequence number, ciphertext, tag.
encryptAead
    :: Bool
    -> Bulk
    -> BulkAEAD
    -> ByteString
    -> Record Plaintext
    -> RecordM ByteString
encryptAead tls13 bulk encryptF content record = do
    let authTagLen = bulkAuthTagLen bulk
        nonceExpLen = bulkExplicitIV bulk
    cst <- getCryptState
    encodedSeq <- encodeWord64 <$> getMacSequence

    let iv = cstIV cst
        ivlen = B.length iv
        Header typ v plainLen = recordToHeader record
        hdrLen = if tls13 then plainLen + fromIntegral authTagLen else plainLen
        hdr = Header typ v hdrLen
        ad
            | tls13 = encodeHeader hdr
            | otherwise = B.concat [encodedSeq, encodeHeader hdr]
        -- Sequence number left-padded with zeros up to the IV length
        -- (assumes 'encodeWord64' yields 8 bytes -- TODO confirm).
        sqnc = B.replicate (ivlen - 8) 0 `B.append` encodedSeq
        nonce
            | nonceExpLen == 0 = B.xor iv sqnc
            | otherwise = B.concat [iv, encodedSeq]
        (e, AuthTag authtag) = encryptF nonce content ad
        econtent
            | nonceExpLen == 0 = e `B.append` B.convert authtag
            | otherwise = B.concat [encodedSeq, e, B.convert authtag]
    modify' incrRecordState
    return econtent
-- | Read the 'CryptState' component of the current record state.
getCryptState :: RecordM CryptState
getCryptState = gets stCryptState