-- | TLS record layer in Rx direction
module Network.TLS.Record.Recv (
    recvRecord12,
    recvRecord13,
) where

import qualified Data.ByteString as B

import Network.TLS.Context.Internal
import Network.TLS.Hooks
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Record
import Network.TLS.Struct
import Network.TLS.Types

----------------------------------------------------------------

-- | The plaintext size limit for records we receive: the value reported by
-- 'getMyRecordLimit' when one is set, otherwise 'defaultRecordSizeLimit'.
getMyPlainLimit :: Context -> IO Int
getMyPlainLimit ctx = orDefault <$> getMyRecordLimit ctx
  where
    orDefault (Just siz) = siz
    orDefault Nothing = defaultRecordSizeLimit

-- | Process one raw record received from the peer.
--
-- The header and raw fragment are first handed to the 'loggingIORecv'
-- logging hook. The fragment is then wrapped as ciphertext and run through
-- 'decryptRecord' inside the receive-side record state ('runRxRecordState').
-- The limit from 'getMyPlainLimit' is passed to 'decryptRecord'
-- (presumably as the maximum accepted plaintext size — confirm in
-- 'decryptRecord').
getRecord
    :: Context
    -> Header
    -> ByteString
    -> IO (Either TLSError (Record Plaintext))
getRecord ctx header content = do
    withLog ctx (\logging -> loggingIORecv logging header content)
    plainLimit <- getMyPlainLimit ctx
    runRxRecordState ctx $
        decryptRecord (rawToRecord header (fragmentCiphertext content)) plainLimit

----------------------------------------------------------------

-- | @exceedsTLSCiphertext overhead len@ is 'True' when the length field
-- @len@ taken from a record header is strictly larger than
-- 'defaultRecordSizeLimit' plus the permitted ciphertext @overhead@.
--
-- For TLS 1.3, the overhead passed in also covers the one extra byte used
-- for the inner content type.
exceedsTLSCiphertext :: Int -> Word16 -> Bool
exceedsTLSCiphertext overhead actual = limit < fromIntegral actual
  where
    limit = defaultRecordSizeLimit + overhead

-- | Receive one full TLS 1.2 record (header and fragment) from the peer and
-- decrypt it with the current receive-side record state.
--
-- RFC 5246 Section 7.2.2: a TLSCiphertext record with a length of more than
-- 2^14+2048 bytes, or one that decrypts to a TLSCompressed record of more
-- than 2^14+1024 bytes, is always fatal and should never be observed
-- between proper implementations (except when messages were corrupted in
-- the network).
--
-- Accordingly, a header whose length field exceeds 'defaultRecordSizeLimit'
-- plus 2048 is rejected with a @record_overflow@ protocol error before its
-- fragment is read.
recvRecord12
    :: Context
    -- ^ TLS context
    -> IO (Either TLSError (Record Plaintext))
recvRecord12 = recvRecordWithOverhead 2048

-- | Receive one full TLS 1.3 record (header and fragment) from the peer and
-- decrypt it with the current receive-side record state.
--
-- RFC 8446 Section 5.2: an AEAD algorithm used in TLS 1.3 MUST NOT produce
-- an expansion greater than 255 octets. An endpoint that receives a record
-- with TLSCiphertext.length larger than 2^14 + 256 octets MUST terminate the
-- connection with a "record_overflow" alert. That limit is the maximum
-- TLSInnerPlaintext length of 2^14 octets, plus 1 octet for ContentType,
-- plus the maximum AEAD expansion of 255 octets.
--
-- Accordingly, a header whose length field exceeds 'defaultRecordSizeLimit'
-- plus 256 is rejected with a @record_overflow@ protocol error before its
-- fragment is read.
recvRecord13 :: Context -> IO (Either TLSError (Record Plaintext))
recvRecord13 = recvRecordWithOverhead 256

-- | Shared receive path for 'recvRecord12' and 'recvRecord13'.
--
-- Steps:
--
-- 1. Read the 5-byte record header and decode it.
-- 2. Reject it with 'maximumSizeExceeded' when its length field exceeds
--    'defaultRecordSizeLimit' plus @overhead@ (see 'exceedsTLSCiphertext').
-- 3. Otherwise read exactly that many fragment bytes and pass the header
--    and fragment to 'getRecord'.
--
-- Any error from reading or decoding is returned unchanged as 'Left'.
recvRecordWithOverhead
    :: Int
    -- ^ permitted ciphertext expansion beyond 'defaultRecordSizeLimit'
    -> Context
    -- ^ TLS context
    -> IO (Either TLSError (Record Plaintext))
recvRecordWithOverhead overhead ctx =
    readExactBytes ctx 5 >>= either (return . Left) (recvLengthE . decodeHeader)
  where
    recvLengthE :: Either TLSError Header -> IO (Either TLSError (Record Plaintext))
    recvLengthE = either (return . Left) recvLength

    recvLength :: Header -> IO (Either TLSError (Record Plaintext))
    recvLength header@(Header _ _ readlen)
        -- Check the declared length before reading, so an oversized
        -- length field never causes the fragment bytes to be read.
        | exceedsTLSCiphertext overhead readlen = return $ Left maximumSizeExceeded
        | otherwise =
            readExactBytes ctx (fromIntegral readlen)
                >>= either (return . Left) (getRecord ctx header)

-- | Protocol error returned when a received record's length field is larger
-- than allowed. It carries the 'RecordOverflow' alert description.
maximumSizeExceeded :: TLSError
maximumSizeExceeded =
    Error_Protocol "record exceeding maximum size" RecordOverflow

----------------------------------------------------------------

-- | Request exactly @sz@ bytes from the peer via 'contextRecv'.
--
-- If the returned chunk has exactly @sz@ bytes, it is returned as 'Right'.
-- Otherwise the context is marked with 'setEOF' and an error is returned:
--
-- * 'Error_EOF' when nothing at all was received;
-- * an 'Error_Packet' describing the short read otherwise.
readExactBytes :: Context -> Int -> IO (Either TLSError ByteString)
readExactBytes ctx sz = do
    bytes <- contextRecv ctx sz
    let got = B.length bytes
    if got /= sz
        then do
            setEOF ctx
            return $ Left $ shortRead bytes got
        else return $ Right bytes
  where
    shortRead bytes got
        | B.null bytes = Error_EOF
        | otherwise =
            Error_Packet
                ( "partial packet: expecting "
                    ++ show sz
                    ++ " bytes, got: "
                    ++ show got
                )