{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

module Network.TLS.Handshake.Server.ClientHello13 (
    processClientHello13,
    sendHRR,
) where

import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Extension
import Network.TLS.Handshake.Common13
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.IO
import Network.TLS.Imports
import Network.TLS.Parameters
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Types

-- | Process a ClientHello on the server side for TLS 1.3 or later.
--
-- The steps are performed in this order:
--
-- 1. Reject the hello when a @pre_shared_key@ extension appears anywhere
--    but in the last position.
-- 2. Intersect the client's ciphers with the server's TLS 1.3 ciphers and
--    let the 'onCipherChoosing' hook pick one.
-- 3. Record in the context whether the client sent @early_data@.
-- 4. Decode the client's @key_share@ extension and select the entry for
--    the first server-supported group that the client provided.
--
-- The result is the selected key share ('Nothing' when no client share
-- matches a server group) together with the chosen cipher, its hash and a
-- flag telling whether the client sent @early_data@.
--
-- Failures are raised through 'throwCore' as 'Error_Protocol':
--
-- * 'IllegalParameter' when @pre_shared_key@ is not the last extension;
-- * 'HandshakeFailure' when no cipher is shared with the client;
-- * 'MissingExtension' when @key_share@ is absent or is not a ClientHello
--   key share.
--
-- 'findKeyShare' may additionally raise 'IllegalParameter'.
processClientHello13
    :: ServerParams
    -> Context
    -> CH
    -> IO (Maybe KeyShareEntry, (Cipher, Hash, Bool))
processClientHello13 sparams ctx CH{..} = do
    when (any isPreSharedKey extensionsButLast) $
        throwCore $
            Error_Protocol "extension pre_shared_key must be last" IllegalParameter
    -- Deciding cipher.
    -- The shared cipherlist can become empty after filtering for compatible
    -- creds, check now before calling onCipherChoosing, which does not handle
    -- empty lists.
    when (null ciphersFilteredVersion) $
        throwCore $
            Error_Protocol "no cipher in common with the TLS 1.3 client" HandshakeFailure
    let usedCipher = onCipherChoosing (serverHooks sparams) TLS13 ciphersFilteredVersion
        usedHash = cipherHash usedCipher
        -- True exactly when an early_data extension is found and decoded.
        rtt0 =
            lookupAndDecode
                EID_EarlyData
                MsgTClientHello
                chExtensions
                False
                (\(EarlyDataIndication _) -> True)
    if rtt0
        then
            -- mark a 0-RTT attempt before a possible HRR, and before updating the
            -- status again if 0-RTT successful
            setEstablished ctx (EarlyDataNotAllowed 3) -- hardcoding
        else
            -- In the case of HRR, EarlyDataNotAllowed is already set.
            -- It should be cleared here.
            setEstablished ctx NotEstablished
    -- Deciding key exchange from key shares
    let require =
            throwCore $
                Error_Protocol
                    "key exchange not implemented, expected key_share extension"
                    MissingExtension
        extract (KeyShareClientHello kses) = return kses
        extract _ = require
    keyShares <-
        lookupAndDecodeAndDo EID_KeyShare MsgTClientHello chExtensions require extract
    mshare <- findKeyShare keyShares serverGroups
    return (mshare, (usedCipher, usedHash, rtt0))
  where
    isPreSharedKey :: ExtensionRaw -> Bool
    isPreSharedKey (ExtensionRaw eid _) = eid == EID_PreSharedKey
    -- Every extension except the last one.  This deliberately avoids the
    -- partial 'init': for an empty extension list it yields [] instead of
    -- raising an exception, so the hello is rejected further down with a
    -- regular protocol error.  The order is reversed, which is irrelevant
    -- because the list is only tested with 'any'.
    extensionsButLast :: [ExtensionRaw]
    extensionsButLast = drop 1 (reverse chExtensions)
    ciphersFilteredVersion :: [Cipher]
    ciphersFilteredVersion = intersectCiphers chCiphers serverCiphers
    serverCiphers :: [Cipher]
    serverCiphers =
        filter
            (cipherAllowedForVersion TLS13)
            (supportedCiphers $ serverSupported sparams)
    serverGroups :: [Group]
    serverGroups = supportedGroups (ctxSupported ctx)

-- | Select the client key share to use.
--
-- The groups in the second argument are tried in order.  For each group
-- the client entries carrying that group are collected:
--
-- * no entry: the next group is tried;
-- * exactly one entry: it is returned, provided 'checkKeyShareKeyLength'
--   accepts it, otherwise an 'IllegalParameter' protocol error is thrown;
-- * several entries: an 'IllegalParameter' protocol error is thrown.
--
-- 'Nothing' is returned once the groups are exhausted without a match.
findKeyShare :: [KeyShareEntry] -> [Group] -> IO (Maybe KeyShareEntry)
findKeyShare _ [] = return Nothing
findKeyShare shares (grp : others) =
    case [entry | entry <- shares, grp == keyShareEntryGroup entry] of
        [] -> findKeyShare shares others
        [entry]
            | checkKeyShareKeyLength entry -> return (Just entry)
            | otherwise ->
                throwCore (Error_Protocol "broken key_share" IllegalParameter)
        _ -> throwCore (Error_Protocol "duplicated key_share" IllegalParameter)

-- | Send a HelloRetryRequest to the client.
--
-- A second retry within the same handshake is refused with a
-- 'HandshakeFailure' protocol error.  Otherwise the retry is recorded in
-- the TLS state, the hello parameters are set for the chosen cipher, and
-- the first server group that also appears in the client's
-- @supported_groups@ extension is requested.  When there is no such group
-- a 'HandshakeFailure' protocol error is thrown.  The request is sent
-- followed by 'sendChangeCipherSpec13'.
--
-- Only the cipher of the triple is used; the other two components are
-- ignored.
sendHRR :: Context -> (Cipher, a, b) -> CH -> IO ()
sendHRR ctx (usedCipher, _, _) CH{..} = do
    alreadyRetried <- usingState_ ctx getTLS13HRR
    when alreadyRetried $
        throwCore (Error_Protocol "Hello retry not allowed again" HandshakeFailure)
    usingState_ ctx (setTLS13HRR True)
    failOnEitherError (usingHState ctx (setHelloParameters13 usedCipher))
    case serverGroups `intersect` clientGroups of
        [] ->
            throwCore
                (Error_Protocol "no group in common with the client for HRR" HandshakeFailure)
        chosen : _ -> do
            usingHState ctx (setTLS13HandshakeMode HelloRetryRequest)
            runPacketFlight ctx $ do
                loadPacket13 ctx (Handshake13 [retryRequestFor chosen])
                sendChangeCipherSpec13 ctx
  where
    serverGroups :: [Group]
    serverGroups = supportedGroups (ctxSupported ctx)
    -- Groups listed in the client's supported_groups extension; empty when
    -- the lookup falls back to its default.
    clientGroups :: [Group]
    clientGroups =
        lookupAndDecode
            EID_SupportedGroups
            MsgTClientHello
            chExtensions
            []
            (\(SupportedGroups offered) -> offered)
    -- The ServerHello carrying the retry request for the given group: it
    -- echoes the client's session and names the selected cipher.
    retryRequestFor :: Group -> Handshake13
    retryRequestFor grp =
        ServerHello13
            hrrRandom
            chSession
            (CipherId (cipherID usedCipher))
            [ toExtensionRaw (KeyShareHRR grp)
            , toExtensionRaw (SupportedVersionsServerHello TLS13)
            ]