{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards #-}

module Network.TLS.Handshake.Random (
    serverRandom,
    serverRandomECH,
    replaceServerRandomECH,
    clientRandom,
    isDowngraded,
) where

import qualified Data.ByteString as B

import Network.TLS.Context.Internal
import Network.TLS.Imports
import Network.TLS.Struct

-- | Produce the server random for the version the server selected, given the
-- versions the server supports.
--
-- When the server supports a newer version than the one in use, the random is
-- 24 bytes from the context RNG followed by an 8-byte downgrade sentinel;
-- otherwise it is 32 bytes straight from the RNG.  The sentinel is also
-- emitted when a lower version has been forced with 'debugVersionForced',
-- which makes it possible to test that clients detect downgrades correctly.
-- No sentinel is emitted when TLS13 is forced on a server that does not list
-- TLS13 among its supported versions: that is not a downgrade, only a
-- consequence of the debug API allowing it.
serverRandom :: Context -> Version -> [Version] -> IO ServerRandom
serverRandom ctx chosenVer suppVers =
    ServerRandom <$> maybe fullyRandom withSentinel sentinel
  where
    -- Pure decision: which sentinel (if any) must terminate the random,
    -- based on our supported versions and the version actually chosen.
    sentinel :: Maybe ByteString
    sentinel
        | TLS13 `elem` suppVers = case chosenVer of
            TLS13 -> Nothing
            TLS12 -> Just suffix12
            _ -> Just suffix11
        | TLS12 `elem` suppVers = case chosenVer of
            -- TLS13 here can only come from the debug API: not a downgrade.
            TLS13 -> Nothing
            TLS12 -> Nothing
            _ -> Just suffix11
        | otherwise = Nothing

    -- All 32 bytes drawn from the context RNG.
    fullyRandom :: IO ByteString
    fullyRandom = getStateRNG ctx 32

    -- 24 bytes drawn from the context RNG, then the 8-byte sentinel.
    withSentinel :: ByteString -> IO ByteString
    withSentinel suff = (<> suff) <$> getStateRNG ctx 24

-- | Server random used on the ECH path: 24 bytes drawn from the context RNG
-- followed by 8 zero bytes.  The zero tail is presumably meant to be
-- overwritten later via 'replaceServerRandomECH', which keeps only the first
-- 24 bytes — confirm at call sites.
serverRandomECH :: Context -> IO ServerRandom
serverRandomECH ctx = withZeroTail <$> getStateRNG ctx 24
  where
    withZeroTail :: ByteString -> ServerRandom
    withZeroTail prefix =
        ServerRandom (prefix <> "\x00\x00\x00\x00\x00\x00\x00\x00")

-- | Keep the first 24 bytes of a server random and append the given bytes
-- after them.  The length of the appended bytes is not checked here; callers
-- presumably pass exactly 8 bytes so the result stays 32 bytes long —
-- NOTE(review): confirm at call sites.
replaceServerRandomECH :: ServerRandom -> ByteString -> ServerRandom
replaceServerRandomECH (ServerRandom original) newTail =
    ServerRandom $ B.take 24 original <> newTail

-- | Test whether the negotiated version was artificially downgraded, that is,
-- lowered for a reason other than the versions supported by the client.
--
-- The test looks for a downgrade sentinel at the end of the server random.
-- If TLS13 is among the supported versions and TLS12 or lower was
-- negotiated, either sentinel ('suffix12' or 'suffix11') counts.  Otherwise,
-- if TLS12 is supported and TLS11 or lower was negotiated, only 'suffix11'
-- counts.  In every other case the result is 'False'.
isDowngraded :: Version -> [Version] -> ServerRandom -> Bool
isDowngraded ver suppVers (ServerRandom bytes)
    | ver <= TLS12, TLS13 `elem` suppVers = endsWithAnyOf [suffix12, suffix11]
    | ver <= TLS11, TLS12 `elem` suppVers = endsWithAnyOf [suffix11]
    | otherwise = False
  where
    endsWithAnyOf :: [ByteString] -> Bool
    endsWithAnyOf = any (`B.isSuffixOf` bytes)

-- | Downgrade sentinels carried in the last 8 bytes of the server random
-- (RFC 8446, Section 4.1.3).  Both are the ASCII bytes @DOWNGRD@ followed by
-- one final byte: @0x01@ in 'suffix12' (used when TLS12 was chosen) and
-- @0x00@ in 'suffix11' (used when TLS11 or lower was chosen).
suffix12, suffix11 :: ByteString
suffix12 = "\x44\x4F\x57\x4E\x47\x52\x44\x01"
suffix11 = "\x44\x4F\x57\x4E\x47\x52\x44\x00"

-- | Generate a client random: 32 bytes drawn from the context RNG.
clientRandom :: Context -> IO ClientRandom
clientRandom ctx = fmap ClientRandom (getStateRNG ctx 32)