{-# LANGUAGE OverloadedStrings #-}
module Network.TLS.Handshake.Server.Common (
applicationProtocol,
checkValidClientCertChain,
clientCertificate,
credentialDigitalSignatureKey,
filterCredentials,
filterCredentialsWithHashSignatures,
isCredentialAllowed,
storePrivInfoServer,
hashAndSignaturesInCommon,
) where
import Control.Monad.State.Strict
import Data.X509 (ExtKeyUsageFlag (..))
import Network.TLS.Context.Internal
import Network.TLS.Credentials
import Network.TLS.Crypto
import Network.TLS.Extension
import Network.TLS.Handshake.Certificate
import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Key
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.Parameters
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Util (catchException)
import Network.TLS.X509
-- | Read the client certificate chain stored in the handshake state and
-- insist that it is present and non-null.
--
-- Throws @'Error_Protocol' errmsg 'UnexpectedMessage'@ through 'throwCore'
-- when no chain has been stored, or when the stored chain satisfies
-- 'isNullCertificateChain'.  Otherwise the stored chain is returned as is.
checkValidClientCertChain
    :: MonadIO m => Context -> String -> m CertificateChain
checkValidClientCertChain ctx errmsg = do
    chain <- usingHState ctx getClientCertChain
    -- Both failure cases report the same caller-supplied message.
    let throwerror = Error_Protocol errmsg UnexpectedMessage
    case chain of
        Nothing -> throwCore throwerror
        Just cc
            | isNullCertificateChain cc -> throwCore throwerror
            | otherwise -> return cc
-- | Return the public key of a credential, but only when its
-- public\/private key pair satisfies 'isDigitalSignaturePair'.
-- Yields 'Nothing' for any other key pair.
credentialDigitalSignatureKey :: Credential -> Maybe PubKey
credentialDigitalSignatureKey cred
    | isDigitalSignaturePair keys = Just pubkey
    | otherwise = Nothing
  where
    -- 'keys' is the whole pair (needed by the predicate); 'pubkey' is its
    -- first component (the value handed back to the caller).
    keys@(pubkey, _) = credentialPublicPrivateKeys cred
-- | Keep only the credentials that satisfy the predicate, preserving their
-- original order inside the 'Credentials' wrapper.
filterCredentials :: (Credential -> Bool) -> Credentials -> Credentials
filterCredentials p (Credentials l) = Credentials (filter p l)
-- | Decide whether a credential may be used for the given protocol version
-- and ClientHello extensions.
--
-- The credential's public key must be 'versionCompatible' with @ver@ and
-- must pass 'satisfiesEcPredicate' with a group predicate chosen as follows:
--
-- * below 'TLS13': the group must be listed in the client's
--   @supported_groups@ extension; when 'lookupAndDecode' falls back to its
--   default, every group is accepted;
--
-- * 'TLS13' and above: every group is accepted here.
isCredentialAllowed :: Version -> [ExtensionRaw] -> Credential -> Bool
isCredentialAllowed ver exts cred =
    pubkey `versionCompatible` ver && satisfiesEcPredicate p pubkey
  where
    (pubkey, _) = credentialPublicPrivateKeys cred
    -- Group predicate handed to 'satisfiesEcPredicate'.
    p
        | ver < TLS13 =
            lookupAndDecode
                EID_SupportedGroups
                MsgTClientHello
                exts
                (const True)
                (\(SupportedGroups sg) -> (`elem` sg))
        | otherwise = const True
-- | Filter candidate credentials with 'credentialMatchesHashSignatures',
-- using the algorithms advertised in the ClientHello extensions.
--
-- The algorithm list is taken from the @signature_algorithms_cert@
-- extension; when that lookup falls back to its default, the
-- @signature_algorithms@ extension is consulted instead; when that lookup
-- falls back as well, the credentials are returned unchanged.
filterCredentialsWithHashSignatures
    :: [ExtensionRaw] -> Credentials -> Credentials
filterCredentialsWithHashSignatures exts =
    lookupAndDecode
        EID_SignatureAlgorithmsCert
        MsgTClientHello
        exts
        lookupSignatureAlgorithms
        (\(SignatureAlgorithmsCert sas) -> withAlgs sas)
  where
    -- Fallback: filter with "signature_algorithms", or not at all ('id').
    lookupSignatureAlgorithms :: Credentials -> Credentials
    lookupSignatureAlgorithms =
        lookupAndDecode
            EID_SignatureAlgorithms
            MsgTClientHello
            exts
            id
            (\(SignatureAlgorithms sas) -> withAlgs sas)

    -- Keep the credentials matching the given hash/signature algorithms.
    withAlgs :: [HashAndSignatureAlgorithm] -> Credentials -> Credentials
    withAlgs sas = filterCredentials (credentialMatchesHashSignatures sas)
-- | Server-side wrapper around 'storePrivInfo': passes the credential's
-- certificate chain and private key to it and discards the 'PubKey' that
-- 'storePrivInfo' returns.
storePrivInfoServer :: MonadIO m => Context -> Credential -> m ()
storePrivInfoServer ctx (cc, privkey) = void (storePrivInfo ctx cc privkey)
-- | Server-side handling of the client's application-layer protocol
-- negotiation (ALPN) extension.
--
-- Returns 'Nothing' when the server has no 'onALPNClientSuggest' hook, or
-- when the extension lookup falls back to its default action.  Otherwise the
-- hook is run on the protocols offered by the client:
--
-- * if the hook answers with the empty string, an
--   @'Error_Protocol' \"no supported application protocols\" 'NoApplicationProtocol'@
--   error is thrown through 'throwCore';
--
-- * otherwise the choice is recorded in the TLS state ('setExtensionALPN',
--   'setNegotiatedProtocol') and the result is the ALPN extension, carrying
--   only the selected protocol, to be sent back to the client.
applicationProtocol
    :: Context -> [ExtensionRaw] -> ServerParams -> IO (Maybe ExtensionRaw)
applicationProtocol ctx exts sparams = case onALPN of
    Nothing -> return Nothing
    Just io ->
        lookupAndDecodeAndDo
            EID_ApplicationLayerProtocolNegotiation
            MsgTClientHello
            exts
            (return Nothing)
            $ select io
  where
    -- Optional application callback that picks one of the offered protocols.
    onALPN = onALPNClientSuggest $ serverHooks sparams

    -- Run the callback, validate its answer, record it, build the reply.
    select io (ApplicationLayerProtocolNegotiation protos) = do
        proto <- io protos
        -- An empty answer from the hook is treated as "nothing acceptable".
        when (proto == "") $
            throwCore $
                Error_Protocol "no supported application protocols" NoApplicationProtocol
        usingState_ ctx $ do
            setExtensionALPN True
            setNegotiatedProtocol proto
        return $
            Just $
                ExtensionRaw
                    EID_ApplicationLayerProtocolNegotiation
                    (extensionEncode $ ApplicationLayerProtocolNegotiation [proto])
-- | Process a certificate chain received from the client.
--
-- Steps, in order:
--
-- 1. run the context's 'hookRecvCertificates' hook on the chain;
--
-- 2. ask the application's 'onClientCertificate' server hook for a verdict;
--    an exception raised by the hook is handed to 'rejectOnException';
--
-- 3. on 'CertificateUsageAccept', run 'verifyLeafKeyUsage' with
--    'KeyUsage_digitalSignature'; on 'CertificateUsageReject', call
--    'certificateRejected' with the reason;
--
-- 4. store the chain in the handshake state with 'setClientCertChain'.
clientCertificate :: ServerParams -> Context -> CertificateChain -> IO ()
clientCertificate sparams ctx certs = do
    -- Run the certificate-received hook.
    ctxWithHooks ctx (`hookRecvCertificates` certs)

    -- Ask the application callback whether the chain is acceptable.
    usage <-
        liftIO $
            catchException
                (onClientCertificate (serverHooks sparams) certs)
                rejectOnException
    case usage of
        CertificateUsageAccept -> verifyLeafKeyUsage [KeyUsage_digitalSignature] certs
        CertificateUsageReject reason -> certificateRejected reason

    -- Remember the chain in the handshake state.
    usingHState ctx $ setClientCertChain certs
-- | Intersect the server's hash\/signature algorithms with the ones
-- advertised by the client in the @signature_algorithms@ extension.
--
-- The result follows the order of the server list (first argument of
-- 'intersect').  When the extension lookup falls back to its default, the
-- client is assumed to support SHA-1 with ECDSA, RSA and DSA.
hashAndSignaturesInCommon
    :: [HashAndSignatureAlgorithm] -> [ExtensionRaw] -> [HashAndSignatureAlgorithm]
hashAndSignaturesInCommon sHashSigs exts = sHashSigs `intersect` cHashSigs
  where
    -- Client algorithms assumed when none could be read from the extension.
    defVal :: [HashAndSignatureAlgorithm]
    defVal =
        [ (HashSHA1, SignatureECDSA)
        , (HashSHA1, SignatureRSA)
        , (HashSHA1, SignatureDSA)
        ]

    -- Algorithms advertised by the client, or 'defVal'.
    cHashSigs :: [HashAndSignatureAlgorithm]
    cHashSigs =
        lookupAndDecode
            EID_SignatureAlgorithms
            MsgTClientHello
            exts
            defVal
            (\(SignatureAlgorithms sas) -> sas)