{-# LANGUAGE OverloadedStrings #-}

module Network.TLS.Handshake.Server.Common (
    applicationProtocol,
    checkValidClientCertChain,
    clientCertificate,
    credentialDigitalSignatureKey,
    filterCredentials,
    filterCredentialsWithHashSignatures,
    isCredentialAllowed,
    storePrivInfoServer,
    hashAndSignaturesInCommon,
) where

import Control.Monad.State.Strict
import Data.X509 (ExtKeyUsageFlag (..))

import Network.TLS.Context.Internal
import Network.TLS.Credentials
import Network.TLS.Crypto
import Network.TLS.Extension
import Network.TLS.Handshake.Certificate
import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Key
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.Parameters
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Util (catchException)
import Network.TLS.X509

-- | Fetch the client certificate chain recorded in the handshake state.
--
-- Throws an 'Error_Protocol' carrying the given message and an
-- 'UnexpectedMessage' alert when no chain was recorded or when the recorded
-- chain is null (according to 'isNullCertificateChain').
checkValidClientCertChain
    :: MonadIO m => Context -> String -> m CertificateChain
checkValidClientCertChain ctx errmsg = do
    chain <- usingHState ctx getClientCertChain
    let throwerror = Error_Protocol errmsg UnexpectedMessage
    case chain of
        Nothing -> throwCore throwerror
        Just cc
            | isNullCertificateChain cc -> throwCore throwerror
            | otherwise -> return cc

-- | Public key of a credential, provided its key pair is accepted by
-- 'isDigitalSignaturePair'; 'Nothing' otherwise.
credentialDigitalSignatureKey :: Credential -> Maybe PubKey
credentialDigitalSignatureKey cred
    | isDigitalSignaturePair keys = Just pubkey
    | otherwise = Nothing
  where
    keys@(pubkey, _) = credentialPublicPrivateKeys cred

-- | Keep only the credentials satisfying the predicate, preserving their
-- original order.
filterCredentials :: (Credential -> Bool) -> Credentials -> Credentials
filterCredentials p (Credentials l) = Credentials (filter p l)

-- | Whether a credential may be used with the negotiated protocol version.
--
-- The credential's public key must pass 'versionCompatible' for the version.
-- Below TLS 1.3 it must additionally satisfy 'satisfiesEcPredicate' against
-- the groups listed in the client's "supported_groups" extension (every group
-- is accepted when that extension is absent).  From TLS 1.3 on, no group
-- restriction is applied here.
isCredentialAllowed :: Version -> [ExtensionRaw] -> Credential -> Bool
isCredentialAllowed ver exts cred =
    pubkey `versionCompatible` ver && satisfiesEcPredicate p pubkey
  where
    (pubkey, _) = credentialPublicPrivateKeys cred
    -- ECDSA keys are tested against supported elliptic curves until TLS12 but
    -- not after.  With TLS13, the curve is linked to the signature algorithm
    -- and client support is tested with signatureCompatible13.
    p
        | ver < TLS13 =
            lookupAndDecode
                EID_SupportedGroups
                MsgTClientHello
                exts
                (const True)
                (\(SupportedGroups sg) -> (`elem` sg))
        | otherwise = const True

-- | Filter a list of candidate credentials with
-- 'credentialMatchesHashSignatures'.
--
-- The algorithms to filter with are taken from the "signature_algorithms_cert"
-- extension when it exists, else from "signature_algorithms" when clients do
-- not implement the newer extension (see RFC 8446 section 4.2.3).  When
-- neither extension is present, the credentials are returned unfiltered.
--
-- The resulting credential list can be used as input to the hybrid
-- cipher-and-certificate selection for TLS12, or to the direct certificate
-- selection simplified with TLS13.  Filtering credential signatures with
-- client-advertised algorithms is not supposed to cause negotiation failure.
-- So if the subsequent selection process reaches a dead end, it should always
-- be restarted with the unfiltered credential list as input (see the fallback
-- certificate chains described in the same RFC section).
--
-- Calling code must not forget to apply the constraints of the
-- "signature_algorithms" extension to any signature-based key exchange
-- derived from the output credentials.  Respecting client constraints on KX
-- signatures is mandatory but not implemented by this function.
filterCredentialsWithHashSignatures
    :: [ExtensionRaw] -> Credentials -> Credentials
filterCredentialsWithHashSignatures exts =
    lookupAndDecode
        EID_SignatureAlgorithmsCert
        MsgTClientHello
        exts
        lookupSignatureAlgorithms
        (\(SignatureAlgorithmsCert sas) -> withAlgs sas)
  where
    -- Used only when "signature_algorithms_cert" is absent; leaves the
    -- credentials untouched when "signature_algorithms" is absent too.
    lookupSignatureAlgorithms =
        lookupAndDecode
            EID_SignatureAlgorithms
            MsgTClientHello
            exts
            id
            (\(SignatureAlgorithms sas) -> withAlgs sas)
    withAlgs sas = filterCredentials (credentialMatchesHashSignatures sas)

-- | Hand the server credential's certificate chain and private key to
-- 'storePrivInfo' (presumably recording them in the context), discarding
-- the public key it returns.
storePrivInfoServer :: MonadIO m => Context -> Credential -> m ()
storePrivInfoServer ctx (cc, privkey) = void (storePrivInfo ctx cc privkey)

-- | ALPN (Application Layer Protocol Negotiation), server side.
--
-- Returns 'Nothing' in either of these cases:
--
-- * the server has no 'onALPNClientSuggest' hook;
-- * the ClientHello carries no ALPN extension (the default branch of
--   'lookupAndDecodeAndDo').
--
-- Otherwise the hook is called with the client's protocol list.  An empty
-- answer aborts with an 'Error_Protocol' carrying a 'NoApplicationProtocol'
-- alert.  Any other answer is recorded in the TLS state and returned as the
-- ALPN extension to send back.
applicationProtocol
    :: Context -> [ExtensionRaw] -> ServerParams -> IO (Maybe ExtensionRaw)
applicationProtocol ctx exts sparams =
    case onALPNClientSuggest (serverHooks sparams) of
        Nothing -> return Nothing
        Just io ->
            lookupAndDecodeAndDo
                EID_ApplicationLayerProtocolNegotiation
                MsgTClientHello
                exts
                (return Nothing)
                (select io)
  where
    select io (ApplicationLayerProtocolNegotiation protos) = do
        proto <- io protos
        -- An empty answer from the hook means none of the client's
        -- protocols is acceptable.
        when (proto == "") $
            throwCore $
                Error_Protocol "no supported application protocols" NoApplicationProtocol
        usingState_ ctx $ do
            setExtensionALPN True
            setNegotiatedProtocol proto
        -- The reply lists exactly the one selected protocol.
        return $
            Just $
                ExtensionRaw
                    EID_ApplicationLayerProtocolNegotiation
                    (extensionEncode (ApplicationLayerProtocolNegotiation [proto]))

-- | Process the certificate chain received from the client.
--
-- The steps are:
--
-- 1. Run the 'hookRecvCertificates' hook.
-- 2. Ask the server's 'onClientCertificate' hook whether the chain is
--    acceptable; any exception it throws is mapped to a usage by
--    'rejectOnException'.
-- 3. For an accepted chain, check with 'verifyLeafKeyUsage' that the leaf
--    allows digital signatures.  A rejected chain goes to
--    'certificateRejected', whose polymorphic result type suggests it throws.
-- 4. If the previous step did not abort, store the chain in the handshake
--    state.
clientCertificate :: ServerParams -> Context -> CertificateChain -> IO ()
clientCertificate sparams ctx certs = do
    -- run certificate recv hook
    ctxWithHooks ctx (`hookRecvCertificates` certs)
    -- Call application callback to see whether the
    -- certificate chain is acceptable.
    usage <-
        catchException
            (onClientCertificate (serverHooks sparams) certs)
            rejectOnException
    case usage of
        CertificateUsageAccept ->
            verifyLeafKeyUsage [KeyUsage_digitalSignature] certs
        CertificateUsageReject reason -> certificateRejected reason
    -- Remember cert chain for later use.
    usingHState ctx $ setClientCertChain certs

----------------------------------------------------------------

-- | Hash/signature algorithm pairs supported by both the server and the
-- client.
--
-- The values in the client's "signature_algorithms" extension are in
-- descending order of preference.  Here, however, the result keeps the order
-- of the server list given as the first argument ('intersect' preserves the
-- order of its first operand), so server preference wins.  When the client
-- sent no "signature_algorithms" extension, the client is assumed to support
-- SHA-1 with ECDSA, RSA and DSA.
hashAndSignaturesInCommon
    :: [HashAndSignatureAlgorithm] -> [ExtensionRaw] -> [HashAndSignatureAlgorithm]
hashAndSignaturesInCommon sHashSigs exts = sHashSigs `intersect` cHashSigs
  where
    -- See Section 7.4.1.4.1 of RFC 5246.
    defVal =
        [ (HashSHA1, SignatureECDSA)
        , (HashSHA1, SignatureRSA)
        , (HashSHA1, SignatureDSA)
        ]
    cHashSigs =
        lookupAndDecode
            EID_SignatureAlgorithms
            MsgTClientHello
            exts
            defVal
            (\(SignatureAlgorithms sas) -> sas)