{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

module Network.TLS.Handshake.Server.ClientHello13 (
    processClientHello13,
) where

import qualified Data.ByteString as B

import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Extension
import Network.TLS.Handshake.Common13
import Network.TLS.Handshake.Signature
import Network.TLS.Handshake.State
import Network.TLS.IO.Encode
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Parameters
import Network.TLS.Session
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Types

-- | Process a ClientHello for TLS 1.3 or later.
--
-- In order, this:
--
-- 1. rejects a ClientHello whose @pre_shared_key@ extension is not the last
--    extension;
-- 2. chooses the cipher via the 'onCipherChoosing' server hook from the
--    ciphers shared with the client and allowed for 'TLS13';
-- 3. records whether the client sent the @early_data@ extension and sets the
--    established state of the context accordingly;
-- 4. picks a client key share with 'findKeyShare';
-- 5. runs 'pskAndEarlySecret';
-- 6. passes the ClientHello stored in the handshake state to
--    'updateTranscriptHash12'.
--
-- Returns the selected key share (if any), the @(cipher, hash, rtt0)@ triple
-- and the result of 'pskAndEarlySecret'.
--
-- Errors are raised with 'throwCore': 'IllegalParameter' for a misplaced
-- @pre_shared_key@, 'HandshakeFailure' when no cipher is shared, and
-- 'MissingExtension' when @key_share@ is absent or of the wrong shape.
processClientHello13
    :: ServerParams
    -> Context
    -> ClientHello
    -> IO
        ( Maybe KeyShareEntry
        , (Cipher, Hash, Bool)
        , (SecretPair EarlySecret, [ExtensionRaw], Bool, Bool)
        )
processClientHello13 sparams ctx ch@CH{..} = do
    when (any isPreSharedKey allButLast) $
        throwCore $
            Error_Protocol "extension pre_shared_key must be last" IllegalParameter
    -- Deciding cipher.
    -- The shared cipherlist can become empty after filtering for compatible
    -- creds, check now before calling onCipherChoosing, which does not handle
    -- empty lists.
    when (null ciphersFilteredVersion) $
        throwCore $
            Error_Protocol "no cipher in common with the TLS 1.3 client" HandshakeFailure
    let usedCipher = onCipherChoosing (serverHooks sparams) TLS13 ciphersFilteredVersion
        usedHash = cipherHash usedCipher
        -- True exactly when the early_data extension is present and decodes.
        rtt0 =
            lookupAndDecode
                EID_EarlyData
                MsgTClientHello
                chExtensions
                False
                (\(EarlyDataIndication _) -> True)
    if rtt0
        then
            -- mark a 0-RTT attempt before a possible HRR, and before updating the
            -- status again if 0-RTT successful
            setEstablished ctx (EarlyDataNotAllowed 3) -- hardcoding
        else
            -- In the case of HRR, EarlyDataNotAllowed is already set.
            -- It should be cleared here.
            setEstablished ctx NotEstablished
    -- Deciding key exchange from key shares
    let require =
            throwCore $
                Error_Protocol
                    "key exchange not implemented, expected key_share extension"
                    MissingExtension
        extract (KeyShareClientHello kses) = return kses
        extract _ = require
    keyShares <-
        lookupAndDecodeAndDo EID_KeyShare MsgTClientHello chExtensions require extract
    mshare <- findKeyShare keyShares serverGroups
    let triple = (usedCipher, usedHash, rtt0)
    pskEarlySecret <- pskAndEarlySecret sparams ctx triple ch
    -- The ClientHello is read back from the handshake state; 'fromJust'
    -- throws if none is stored there.
    clientHello <- fromJust <$> usingHState ctx getClientHello
    void $ updateTranscriptHash12 ctx $ ClientHello clientHello
    return (mshare, triple, pskEarlySecret)
  where
    isPreSharedKey (ExtensionRaw eid _) = eid == EID_PreSharedKey
    -- Every extension except the final one, in the original order.
    -- Unlike 'init', this is total: an empty extension list yields an empty
    -- list instead of raising an exception, so the missing key_share check
    -- further down reports the problem as a TLS error.
    allButLast = zipWith const chExtensions (drop 1 chExtensions)
    ciphersFilteredVersion = intersectCiphers chCiphers serverCiphers
    serverCiphers =
        filter
            (cipherAllowedForVersion TLS13)
            (supportedCiphers $ serverSupported sparams)
    serverGroups = supportedGroups (ctxSupported ctx)

-- | Find the client key share to use, walking the given groups in order.
--
-- For the first group that has at least one matching entry:
--
-- * exactly one entry: it is returned, after 'checkKeyShareKeyLength'
--   accepts it (otherwise an 'IllegalParameter' error is thrown);
-- * more than one entry: an 'IllegalParameter' error is thrown.
--
-- 'Nothing' is returned when no group has a matching entry.
findKeyShare :: [KeyShareEntry] -> [Group] -> IO (Maybe KeyShareEntry)
findKeyShare _ [] = return Nothing
findKeyShare entries (grp : others) =
    case [ent | ent <- entries, grp == keyShareEntryGroup ent] of
        -- Nothing offered for this group: try the next one.
        [] -> findKeyShare entries others
        [entry]
            | checkKeyShareKeyLength entry -> return (Just entry)
            | otherwise ->
                throwCore (Error_Protocol "broken key_share" IllegalParameter)
        _ ->
            throwCore (Error_Protocol "duplicated key_share" IllegalParameter)

-- | Select a PSK offered in the ClientHello (if one is usable) and compute
-- the early secret from it.
--
-- The third argument is the @(cipher, hash, rtt0)@ triple already decided for
-- this handshake.
--
-- The result is @(earlyKey, preSharedKeyExt, authenticated, is0RTTvalid)@:
--
-- * @earlyKey@ is the result of 'calculateEarlySecret' applied to the chosen
--   PSK, or to a zero-filled string of the hash digest size when no PSK was
--   chosen;
-- * @preSharedKeyExt@ holds the @pre_shared_key@ extension built with
--   'PreSharedKeyServerHello', or is empty when no PSK was chosen;
-- * @authenticated@ is 'True' exactly when a PSK was chosen;
-- * @is0RTTvalid@ is 'True' when a PSK was chosen and the resumed session has
--   version 'TLS13' and the same cipher as the one in use.
--
-- Throws a 'MissingExtension' error when a PSK is offered without any
-- @psk_key_exchange_modes@, and fails via 'decryptError' when the binder of
-- the chosen PSK does not match.
pskAndEarlySecret
    :: ServerParams
    -> Context
    -> (Cipher, Hash, Bool)
    -> ClientHello
    -> IO (SecretPair EarlySecret, [ExtensionRaw], Bool, Bool)
pskAndEarlySecret sparams ctx (usedCipher, usedHash, rtt0) CH{..} = do
    (psk, binderInfo, is0RTTvalid) <- choosePSK
    earlyKey <- calculateEarlySecret ctx choice (Left psk)
    let earlySecret = pairBase earlyKey
        authenticated = isJust binderInfo
    preSharedKeyExt <- checkBinder earlySecret binderInfo
    return (earlyKey, preSharedKeyExt, authenticated, is0RTTvalid)
  where
    choice = makeCipherChoice TLS13 usedCipher

    -- Result used whenever no PSK is selected: a zero-filled key, no binder
    -- to verify, and 0-RTT not valid. The handshake then proceeds as a full
    -- handshake.
    noPSK = return (zero, Nothing, False)

    choosePSK =
        lookupAndDecodeAndDo
            EID_PreSharedKey
            MsgTClientHello
            chExtensions
            noPSK
            selectPSK

    -- Only the first offered identity and the first binder are considered.
    selectPSK (PreSharedKeyClientHello (PskIdentity identity obfAge : _) bnds@(bnd : _)) = do
        when (null dhModes) $
            throwCore $
                Error_Protocol "no psk_key_exchange_modes extension" MissingExtension
        if PSK_DHE_KE `elem` dhModes
            then do
                -- Each binder counts for its length plus one, plus two overall;
                -- presumably the encoded size of the binders list, which is
                -- handed to 'makePSKBinder' -- TODO confirm against the encoder.
                let len = sum (map (\x -> B.length x + 1) bnds) + 2
                    mgr = sharedSessionManager $ serverShared sparams
                -- sessionInvalidate is not used for TLS 1.3
                -- because PSK is always changed.
                -- So, identity is not stored in Context.
                msdata <-
                    if rtt0
                        then sessionResumeOnlyOnce mgr identity
                        else sessionResume mgr identity
                case msdata of
                    Just sdata
                        -- Session data without ticket info cannot be used as
                        -- a TLS 1.3 PSK. Fall through to the full handshake
                        -- instead of crashing on a partial 'fromJust'.
                        | Just tinfo <- sessionTicketInfo sdata -> do
                            let psk = sessionSecret sdata
                            isFresh <- checkFreshness tinfo obfAge
                            (isPSKvalid, is0RTTvalid) <- checkSessionEquality sdata
                            if isPSKvalid && isFresh
                                then return (psk, Just (bnd, 0 :: Int, len), is0RTTvalid)
                                else -- fall back to full handshake
                                    noPSK
                    _ -> noPSK
            else noPSK
    selectPSK _ = noPSK

    -- Verify the binder of the chosen PSK and build the ServerHello
    -- extension carrying the selected identity index.
    checkBinder _ Nothing = return []
    checkBinder earlySecret (Just (binder, n, tlen)) = do
        -- 'fromJust' throws if no ClientHello is stored in the handshake state.
        ch <- fromJust <$> usingHState ctx getClientHello
        let ech = encodeHandshake $ ClientHello ch
            binder' = makePSKBinder earlySecret usedHash tlen ech
        -- NOTE(review): plain ByteString equality; consider a constant-time
        -- comparison if one is available in this code base.
        unless (binder == binder') $
            decryptError "PSK binder validation failed"
        return [toExtensionRaw $ PreSharedKeyServerHello $ fromIntegral n]

    -- Compare the stored session with the current handshake.
    -- Returns @(isPSKvalid, is0RTTvalid)@.
    checkSessionEquality sdata = do
        msni <- usingState_ ctx getClientSNI
        -- ALPN should be checked.
        -- But it's an extension in EE, sigh.
        --        malpn <- usingState_ ctx getNegotiatedProtocol
        let isSameSNI = sessionClientSNI sdata == msni
            isSameCipher = sessionCipher sdata == cipherID usedCipher
            ciphers = supportedCiphers $ serverSupported sparams
            scid = sessionCipher sdata
            isSameKDF = case findCipher scid ciphers of
                Nothing -> False
                Just c -> cipherHash c == cipherHash usedCipher
            isSameVersion = TLS13 == sessionVersion sdata
            --            isSameALPN = sessionALPN sdata == malpn
            isPSKvalid = isSameKDF && isSameSNI -- fixme: SNI is not required
            is0RTTvalid = isSameVersion && isSameCipher -- && isSameALPN
        return (isPSKvalid, is0RTTvalid)

    -- Key exchange modes from psk_key_exchange_modes; empty when the
    -- extension is absent.
    dhModes =
        lookupAndDecode
            EID_PskKeyExchangeModes
            MsgTClientHello
            chExtensions
            []
            (\(PskKeyExchangeModes ms) -> ms)

    hashSize = hashDigestSize usedHash
    -- Zero-filled stand-in for the PSK when none is selected.
    zero = B.replicate hashSize 0