module Network.TLS.IO.Encode (
    encodePacket12,
    encodePacket13,
    updateTranscriptHash12,
    updateTranscriptHash13,
) where

import Control.Concurrent.MVar
import Control.Monad.State.Strict
import qualified Data.ByteString as B
import Data.IORef

import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.Handshake.State
import Network.TLS.Handshake.TranscriptHash
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Packet13
import Network.TLS.Parameters
import Network.TLS.Record
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Types (Role (..))
import Network.TLS.Util

-- | Serialise a TLS 1.2 (or older) 'Packet' into wire bytes.
--
-- The packet is cut into plaintext fragments by 'packetToFragments12'
-- (honouring the peer record limit from 'getPeerRecordLimit'). Each fragment
-- is wrapped in a 'Record' stamped with the version chosen by
-- 'decideRecordVersion' and encoded with 'recordEncode12' of the supplied
-- record layer. The encoded records are joined with 'mconcat'; if any record
-- fails to encode, the first 'TLSError' is returned instead.
--
-- Side effects:
--
-- * handshake messages go through 'updateTranscriptHash12';
-- * sending 'ChangeCipherSpec' installs the pending transmit state via
--   'switchTxEncryption'.
encodePacket12
    :: Monoid bytes
    => Context
    -> RecordLayer bytes
    -> Packet
    -> IO (Either TLSError bytes)
encodePacket12 ctx recordLayer pkt = do
    (recordVersion, _) <- decideRecordVersion ctx
    limit <- getPeerRecordLimit ctx
    fragments <- packetToFragments12 ctx limit pkt
    let toRecord frag = Record (packetType pkt) recordVersion (fragmentPlaintext frag)
    encoded <- forEitherM (map toRecord fragments) (recordEncode12 recordLayer ctx)
    -- The switch runs only after the records above have been encoded, and
    -- regardless of whether that encoding succeeded.
    when (pkt == ChangeCipherSpec) $
        switchTxEncryption ctx
    return (fmap mconcat encoded)

-- | Turn a TLS 1.2 'Packet' into the plaintext payloads to be framed as
-- records.
--
-- Handshake messages are split:
--
-- * each one is encoded through 'updateTranscriptHash12';
-- * the results are concatenated;
-- * the bytes are cut into chunks with 'getChunks' using the given length
--   limit.
--
-- Alerts and ChangeCipherSpec yield a single payload.
--
-- AppData is deliberately not fragmented here. Callers of sendPacket
-- fragment application data themselves, so that the empty-packet
-- countermeasure can be applied to each fragment independently.
packetToFragments12 :: Context -> Maybe Int -> Packet -> IO [ByteString]
packetToFragments12 ctx mlen packet = case packet of
    Handshake msgs -> do
        encodedMsgs <- mapM (updateTranscriptHash12 ctx) msgs
        return (getChunks mlen (B.concat encodedMsgs))
    Alert alerts -> single (encodeAlerts alerts)
    ChangeCipherSpec -> single encodeChangeCipherSpec
    AppData payload -> single payload
  where
    -- A packet that is emitted as exactly one payload.
    single bs = return [bs]

-- | Make the pending transmit record state the active one.
--
-- Invoked after a ChangeCipherSpec packet has been encoded (see
-- 'encodePacket12'). The state staged in 'hstPendingTxState' replaces the
-- contents of 'ctxTxRecordState'.
--
-- 'ctxNeedEmptyPacket' is additionally set to 'True' when all of these hold:
--
-- * the negotiated version is TLS 1.0 or older;
-- * this side is the client;
-- * the new cipher's bulk algorithm has a non-zero block size;
-- * 'supportedEmptyPacket' is enabled.
--
-- This turns on the empty-packet countermeasure.
--
-- If no pending transmit state has been staged, this fails immediately with
-- an 'error' (an 'ErrorCall'), leaving 'ctxTxRecordState' untouched.
switchTxEncryption :: Context -> IO ()
switchTxEncryption ctx = do
    pending <- usingHState ctx $ gets hstPendingTxState
    -- Checked eagerly in IO. A lazy 'fromJust' would store a bottom in
    -- ctxTxRecordState that only blows up on the next record encode, far from
    -- the actual cause.
    tx <-
        maybe
            (error "switchTxEncryption: no pending TX record state")
            return
            pending
    (ver, role) <- usingState_ ctx $ (,) <$> getVersion <*> getRole
    modifyMVar_ (ctxTxRecordState ctx) (const (return tx))
    when (needEmptyPacket ver role tx) $
        writeIORef (ctxNeedEmptyPacket ctx) True
  where
    -- Conditions under which the empty-packet countermeasure is enabled.
    needEmptyPacket ver role tx =
        ver <= TLS10
            && role == ClientRole
            && isCBC tx
            && supportedEmptyPacket (ctxSupported ctx)

    -- A cipher whose bulk algorithm has a positive block size.
    isCBC :: RecordState -> Bool
    isCBC tx = maybe False (\c -> bulkBlockSize (cipherBulk c) > 0) (stCipher tx)

-- | Encode a TLS 1.2 handshake message and record it for later handshake
-- steps.
--
-- If 'certVerifyHandshakeMaterial' holds for the message, its encoding is
-- passed to 'addHandshakeMessage' in the handshake state. If
-- 'finishedHandshakeMaterial' holds, the encoding is fed to
-- 'updateTranscriptHash', labelled with the 'show'n handshake type.
--
-- The encoded bytes are returned in every case.
updateTranscriptHash12 :: Context -> Handshake -> IO ByteString
updateTranscriptHash12 ctx hs = do
    when (certVerifyHandshakeMaterial hs) $
        usingHState ctx (addHandshakeMessage bytes)
    when (finishedHandshakeMaterial hs) $
        updateTranscriptHash ctx hsLabel bytes
    return bytes
  where
    bytes = encodeHandshake hs
    hsLabel = show (typeOfHandshake hs)

----------------------------------------------------------------

-- | Serialise a TLS 1.3 'Packet13' into wire bytes.
--
-- The packet is cut into plaintext fragments by 'packetToFragments13',
-- honouring the peer record limit. Each fragment becomes a 'Record' and is
-- encoded with 'recordEncode13'. The results are joined with 'mconcat', or
-- the first 'TLSError' is returned.
--
-- Records are always stamped 'TLS12'. In TLS 1.3 the record-layer version
-- field is a legacy value fixed at 0x0303 (RFC 8446, section 5.1).
encodePacket13
    :: Monoid bytes
    => Context
    -> RecordLayer bytes
    -> Packet13
    -> IO (Either TLSError bytes)
encodePacket13 ctx recordLayer pkt = do
    limit <- getPeerRecordLimit ctx
    fragments <- packetToFragments13 ctx limit pkt
    let toRecord = Record (contentType pkt) TLS12 . fragmentPlaintext
    fmap mconcat <$> forEitherM (map toRecord fragments) (recordEncode13 recordLayer ctx)

-- | Turn a TLS 1.3 'Packet13' into the plaintext payloads to be framed as
-- records.
--
-- Handshake messages are encoded one by one via 'updateTranscriptHash13',
-- concatenated, and cut into pieces with 'getChunks' using the given length
-- limit. Alerts, application data and ChangeCipherSpec each produce exactly
-- one payload.
packetToFragments13 :: Context -> Maybe Int -> Packet13 -> IO [ByteString]
packetToFragments13 ctx mlen packet = case packet of
    Handshake13 msgs ->
        getChunks mlen . B.concat <$> traverse (updateTranscriptHash13 ctx) msgs
    Alert13 alerts -> return [encodeAlerts alerts]
    AppData13 payload -> return [payload]
    ChangeCipherSpec13 -> return [encodeChangeCipherSpec]

-- | Encode a TLS 1.3 handshake message and, unless it is excluded, record it.
--
-- For most messages, the encoding is first fed to 'updateTranscriptHash'
-- (labelled with the 'show'n handshake type). It is then passed to
-- 'addHandshakeMessage' in the handshake state.
--
-- NewSessionTicket13 and KeyUpdate13 skip both steps. These are
-- post-handshake messages (RFC 8446, section 4.6), outside the handshake
-- transcript.
--
-- The encoded bytes are returned in every case.
updateTranscriptHash13 :: Context -> Handshake13 -> IO ByteString
updateTranscriptHash13 ctx hs = do
    when (not (skipped hs)) $ do
        updateTranscriptHash ctx (show (typeOfHandshake13 hs)) bytes
        usingHState ctx (addHandshakeMessage bytes)
    return bytes
  where
    bytes = encodeHandshake13 hs

    skipped :: Handshake13 -> Bool
    skipped msg = case msg of
        NewSessionTicket13{} -> True
        KeyUpdate13{} -> True
        _ -> False