{-# LANGUAGE OverloadedStrings #-}
module Kubernetes.Client.Internal.TLSUtils where

import Control.Exception.Safe     (Exception, MonadThrow, throwM)
import Control.Monad.IO.Class     (MonadIO, liftIO)
import Data.ByteString            (ByteString)
import Data.Default.Class         (def)
import Data.Either                (rights)
import Data.Either.Combinators    (mapLeft)
import Data.PEM                   (pemContent, pemParseBS)
import Data.X509                  (SignedCertificate, decodeSignedCertificate)
import Data.X509.CertificateStore (CertificateStore, makeCertificateStore)
import Lens.Micro                 ((&), (.~), Lens', lens, set)
import Network.TLS                (Credential, credentialLoadX509FromMemory, defaultParamsClient)
import System.X509                (getSystemCertificateStore)

import qualified Data.ByteString        as B
import qualified Data.ByteString.Base64 as B64
import qualified Data.X509              as X509
import qualified Data.X509.Validation   as X509
import qualified Network.TLS            as TLS
import qualified Network.TLS.Extra      as TLS

-- | Default TLS client settings using the system CA store.
--
-- Starts from 'defaultParamsClient' with an empty host name and an empty
-- server-identifier suffix. It then restricts the offered cipher suites to
-- 'TLS.ciphersuite_strong' and trusts the certificates returned by
-- 'getSystemCertificateStore'.
--
-- NOTE(review): the host name is left empty here; presumably the connection
-- layer supplies the real server identification later — confirm before
-- relying on SNI or host-name validation with these params as-is.
defaultTLSClientParams :: IO TLS.ClientParams
defaultTLSClientParams = do
    let defParams = defaultParamsClient "" ""
    systemCAStore <- getSystemCertificateStore
    pure defParams
        { -- Library-default 'TLS.Supported' settings, but only offer the
          -- strong cipher suites.
          TLS.clientSupported = def
            { TLS.supportedCiphers = TLS.ciphersuite_strong
            }
          -- Keep the default shared settings; swap in only the CA store.
        , TLS.clientShared    = (TLS.clientShared defParams)
            { TLS.sharedCAStore = systemCAStore
            }
        }

-- | Parses a PEM-encoded @ByteString@ into a list of certificates.
--
-- Returns @Left ('PEMParsingFailed' msg)@ if the input is not well-formed PEM.
-- PEM blocks that parse but do not decode as X.509 certificates are skipped,
-- so the result may contain fewer entries than there are PEM blocks, or none
-- at all.
parsePEMCerts :: B.ByteString -> Either ParseCertException [SignedCertificate]
parsePEMCerts pemBS = do
    pems <- pemParseBS pemBS & mapLeft PEMParsingFailed
    -- 'rights' keeps only the blocks that decoded successfully; per-block
    -- decoding errors are discarded rather than reported.
    pure (rights (map (decodeSignedCertificate . pemContent) pems))

-- | Updates client params by replacing the CA store with the certificates
-- parsed from a PEM-encoded bytestring of CA certificates.
--
-- Fails with 'PEMParsingFailed' if the bytestring is not valid PEM. The
-- existing CA store is replaced, not extended.
--
-- NOTE(review): if the PEM contains no decodable certificates, the CA store
-- becomes empty. That presumably makes every chain validation fail unless
-- server validation is disabled — confirm this is the intended outcome.
updateClientParams :: TLS.ClientParams -> ByteString -> Either ParseCertException TLS.ClientParams
updateClientParams cp certText =
    (\certs -> setCAStore certs cp) <$> parsePEMCerts certText

-- | Use a custom CA store.
--
-- The CA store in the params' shared settings is replaced, not extended, by
-- a store built from the given certificates.
setCAStore :: [SignedCertificate] -> TLS.ClientParams -> TLS.ClientParams
setCAStore certs tlsParams =
    tlsParams & clientSharedL . sharedCAStoreL .~ makeCertificateStore certs

-- | Use a client cert for authentication.
--
-- Installs an @onCertificateRequest@ hook that answers every certificate
-- request with the given credential. The certificate types, signature
-- algorithms and distinguished names in the request are ignored.
setClientCert :: Credential -> TLS.ClientParams -> TLS.ClientParams
setClientCert cred = set onCertificateRequestL (const (pure (Just cred)))

-- | Lens onto the 'TLS.ClientHooks' stored in a 'TLS.ClientParams'.
clientHooksL :: Lens' TLS.ClientParams TLS.ClientHooks
clientHooksL =
    lens TLS.clientHooks (\params hooks -> params { TLS.clientHooks = hooks })

-- | Lens onto the server-certificate validation hook (@onServerCertificate@
-- inside the client hooks).
--
-- The hook returns a list of validation failures. An empty list means the
-- chain is accepted; 'disableServerCertValidation' relies on this.
onServerCertificateL
  :: Lens' TLS.ClientParams
           (  CertificateStore
           -> TLS.ValidationCache
           -> X509.ServiceID
           -> X509.CertificateChain
           -> IO [X509.FailedReason])
onServerCertificateL =
    clientHooksL
      . lens TLS.onServerCertificate
             (\hooks osc -> hooks { TLS.onServerCertificate = osc })

-- | Lens onto the 'TLS.Shared' settings (CA store, etc.) of a
-- 'TLS.ClientParams'.
clientSharedL :: Lens' TLS.ClientParams TLS.Shared
clientSharedL =
    lens TLS.clientShared (\tlsParams cs -> tlsParams { TLS.clientShared = cs })

-- | Lens onto the CA 'CertificateStore' held in the 'TLS.Shared' settings.
sharedCAStoreL :: Lens' TLS.Shared CertificateStore
sharedCAStoreL =
    lens TLS.sharedCAStore (\shared store -> shared { TLS.sharedCAStore = store })

-- | Don't check whether the cert presented by the server matches the name of
-- the server you are connecting to. This is necessary if you specify the
-- server host by its IP address.
--
-- The server-certificate hook is replaced by 'X509.validate' with these
-- settings:
--
-- * SHA-256 as the hash algorithm;
-- * default validation hooks;
-- * default validation checks with only 'X509.checkFQHN' turned off.
--
-- All other default checks stay enabled. Any previously installed
-- server-certificate hook is discarded.
disableServerNameValidation :: TLS.ClientParams -> TLS.ClientParams
disableServerNameValidation =
    set onServerCertificateL
        (X509.validate X509.HashSHA256 def (def { X509.checkFQHN = False }))

-- | Insecure mode. The client will not validate the server cert at all.
--
-- WARNING: the installed hook reports no failures for any certificate
-- chain. Any server, including a man-in-the-middle, is therefore accepted.
-- Avoid outside of testing.
disableServerCertValidation :: TLS.ClientParams -> TLS.ClientParams
disableServerCertValidation = set onServerCertificateL acceptAnyChain
  where
    -- Ignore the CA store, cache, service id and chain; report no failures.
    acceptAnyChain _store _cache _serviceId _chain = pure []

-- | Lens onto the hook used to answer a server's client-certificate request
-- (@onCertificateRequest@ inside the client hooks).
--
-- The hook receives three things from the request:
--
-- * the requested certificate types;
-- * an optional list of signature algorithms;
-- * a list of distinguished names.
--
-- It returns the credential to present, if any.
onCertificateRequestL
  :: Lens' TLS.ClientParams
           (  ( [TLS.CertificateType]
              , Maybe [TLS.HashAndSignatureAlgorithm]
              , [X509.DistinguishedName] )
           -> IO (Maybe Credential))
onCertificateRequestL =
    clientHooksL
      . lens TLS.onCertificateRequest
             (\hooks ocr -> hooks { TLS.onCertificateRequest = ocr })

-- | Loads certificates from a PEM-encoded file.
--
-- The whole file is read as a strict 'B.ByteString'. If the contents are not
-- well-formed PEM, a 'PEMParsingFailed' exception is raised via 'throwM'.
-- Blocks that do not decode as certificates are skipped, as in
-- 'parsePEMCerts'. Exceptions from reading the file are not caught here.
loadPEMCerts :: (MonadIO m, MonadThrow m) => FilePath -> m [SignedCertificate]
loadPEMCerts pemFile = do
    pemBS <- liftIO (B.readFile pemFile)
    either throwM pure (parsePEMCerts pemBS)

-- | Loads a Base64-encoded certificate and private key as a 'Credential'.
--
-- Both arguments are Base64-decoded first. A decoding failure of either one
-- is thrown as 'Base64ParsingFailed'. The decoded bytes are then passed to
-- 'credentialLoadX509FromMemory', and a failure there is thrown as
-- 'FailedToLoadCredential'.
loadB64EncodedCert :: (MonadThrow m) => B.ByteString -> B.ByteString -> m Credential
loadB64EncodedCert certB64 keyB64 = either throwM pure $ do
    certText <- B64.decode certB64 & mapLeft Base64ParsingFailed
    keyText  <- B64.decode keyB64  & mapLeft Base64ParsingFailed
    credentialLoadX509FromMemory certText keyText
      & mapLeft FailedToLoadCredential

-- | Errors raised while parsing or loading TLS certificate material.
-- Each constructor carries the error message reported by the underlying
-- decoder or loader.
data ParseCertException
  = PEMParsingFailed String
    -- ^ The input was not well-formed PEM.
  | Base64ParsingFailed String
    -- ^ A certificate or key was not valid Base64.
  | FailedToLoadCredential String
    -- ^ The decoded certificate and key could not be loaded as a credential.
  deriving Show

instance Exception ParseCertException