{-# LANGUAGE AllowAmbiguousTypes #-}
module Servant.Auth.Hmac.Crypto (
    
    SecretKey (..),
    Signature (..),
    sign,
    signSHA256,
    
    RequestPayload (..),
    requestSignature,
    verifySignatureHmac,
    whitelistHeaders,
    keepWhitelistedHeaders,
    
    authHeaderName,
) where
import Data.List (sort, uncons)

import Crypto.Hash (hash)
import Crypto.Hash.Algorithms (MD5, SHA256)
import Crypto.Hash.IO (HashAlgorithm)
import Crypto.MAC.HMAC (HMAC (hmacGetDigest), hmac)
import Data.ByteString (ByteString)
import Data.CaseInsensitive (foldedCase)
import Network.HTTP.Types (Header, HeaderName, Method, RequestHeaders)

import qualified Data.ByteArray as BA (constEq, convert)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as Base64
import qualified Data.ByteString.Lazy as LBS
-- | Wrapper around the raw bytes of the shared HMAC secret.
--
-- No instances are derived for this type (presumably so the secret cannot be
-- shown or compared by accident -- NOTE(review): confirm intent).
newtype SecretKey = SecretKey {unSecretKey :: ByteString}
-- | A message authentication code. When produced by 'sign', the bytes are the
-- Base64 encoding of the raw HMAC digest.
--
-- The derived 'Eq' is ordinary structural equality on the underlying
-- 'ByteString'. It stops at the first differing byte and is therefore NOT a
-- constant-time comparison.
newtype Signature = Signature {unSignature :: ByteString}
    deriving (Eq)
{- | Compute the HMAC of a message under a secret key and return the digest
Base64-encoded.

The hash algorithm is chosen by the @algo@ type argument. It does not appear
in any argument type, so callers pick it with a type application, for example
@sign \@SHA256 key msg@.
-}
sign ::
    forall algo.
    (HashAlgorithm algo) =>
    -- | Secret key used to key the HMAC
    SecretKey ->
    -- | Message to authenticate
    ByteString ->
    -- | Base64 encoding of the raw HMAC digest
    Signature
sign (SecretKey key) message = Signature (Base64.encode rawDigest)
  where
    -- The annotation fixes the HMAC's hash algorithm to the caller's @algo@.
    mac :: HMAC algo
    mac = hmac key message

    -- Unwrap the digest into plain bytes before Base64-encoding it.
    rawDigest :: ByteString
    rawDigest = BA.convert (hmacGetDigest mac)
{-# INLINE sign #-}
-- | 'sign' with the hash algorithm fixed to 'SHA256', i.e. HMAC-SHA256 whose
-- digest is returned Base64-encoded.
signSHA256 :: SecretKey -> ByteString -> Signature
signSHA256 key message = sign @SHA256 key message
{-# INLINE signSHA256 #-}
-- | The parts of an HTTP request that are covered by the signature
-- (see 'requestSignature').
data RequestPayload = RequestPayload
    { rpMethod :: !Method
    -- ^ HTTP method
    , rpContent :: !ByteString
    -- ^ Raw request body; only its MD5 digest enters the string to sign
    , rpHeaders :: !RequestHeaders
    -- ^ Request headers; every header in this list is signed
    , rpRawUrl :: !ByteString
    -- ^ Raw request URL, included verbatim in the string to sign
    }
    deriving (Show)
{- | Sign an HTTP request payload with the given signing function and key.

The string to sign joins these four fields with @\\n@:

1. the HTTP method;
2. the raw (binary, not hex-encoded) MD5 digest of the body;
3. the normalized headers;
4. the raw URL.

To normalize the headers, each header name is case-folded and its value is
appended directly, with no separator. The resulting lines are sorted and
joined with @\\n@. For example, @User-Agent: Mozilla/5.0@ and
@Host: foo.bar.com@ normalize to:

@
hostfoo.bar.com
user-agentMozilla/5.0
@

NOTE(review): because header names and values are concatenated without a
delimiter, and the body is summarized with MD5, this canonical form is weaker
than it could be. Changing either would change every signature and break
existing clients, so it is kept as is.
-}
requestSignature ::
    -- | Signing function, e.g. 'signSHA256'
    (SecretKey -> ByteString -> Signature) ->
    -- | Secret key to sign with
    SecretKey ->
    -- | Payload to sign
    RequestPayload ->
    Signature
requestSignature signer sk payload = signer sk (stringToSign payload)
  where
    -- Canonical, newline-separated representation of the payload.
    stringToSign :: RequestPayload -> ByteString
    stringToSign RequestPayload{..} =
        BS.intercalate
            "\n"
            [ rpMethod
            , hashMD5 rpContent
            , canonicalHeaders rpHeaders
            , rpRawUrl
            ]

    -- Sorting makes the result independent of the order the headers arrive in.
    canonicalHeaders :: [Header] -> ByteString
    canonicalHeaders headers =
        BS.intercalate "\n" (sort [foldedCase name <> value | (name, value) <- headers])
{- | Header names kept by 'keepWhitelistedHeaders': the signature header
('authHeaderName'), @Host@ and @Accept-Encoding@.

'HeaderName' is case-insensitive, so membership checks against this list
ignore case.
-}
whitelistHeaders :: [HeaderName]
whitelistHeaders = authHeaderName : ["Host", "Accept-Encoding"]
-- | Drop every header whose name is not in 'whitelistHeaders'. The remaining
-- headers keep their original order.
keepWhitelistedHeaders :: [Header] -> [Header]
keepWhitelistedHeaders headers =
    [header | header@(name, _) <- headers, name `elem` whitelistHeaders]
{- | Check that a request carries a valid HMAC signature.

The request is expected to contain a header

@
Authentication: HMAC <signature>
@

That header is removed from the payload. The remaining payload is re-signed
with 'requestSignature', and the result is compared with @<signature>@.

Returns 'Nothing' when the signature matches, or 'Just' an error message
saying why verification failed.

The comparison is constant-time ('BA.constEq'), so response timing does not
reveal how many leading bytes of a guessed signature were correct.
-}
verifySignatureHmac ::
    -- | Signing function, e.g. 'signSHA256'
    (SecretKey -> ByteString -> Signature) ->
    -- | Secret key the request is expected to be signed with
    SecretKey ->
    -- | Received payload, still containing the authentication header
    RequestPayload ->
    Maybe LBS.ByteString
verifySignatureHmac signer sk signedPayload = case unsignedPayload of
    Left err -> Just err
    Right (pay, Signature received) ->
        -- Do not use '==' here: it stops at the first mismatching byte,
        -- which lets an attacker recover a valid signature byte by byte.
        if BA.constEq received (unSignature (requestSignature signer sk pay))
            then Nothing
            else Just "Signatures don't match"
  where
    -- Split the authentication header off the payload and strip its "HMAC "
    -- prefix. The result is the payload as it was signed, plus the claimed
    -- signature.
    unsignedPayload :: Either LBS.ByteString (RequestPayload, Signature)
    unsignedPayload = case extractOn isAuthHeader (rpHeaders signedPayload) of
        (Nothing, _) -> Left "No 'Authentication' header"
        (Just (_, val), headers) -> case BS.stripPrefix "HMAC " val of
            Just sig ->
                Right
                    ( signedPayload{rpHeaders = headers}
                    , Signature sig
                    )
            Nothing -> Left "Can not strip 'HMAC' prefix in header"
-- | Name of the header that carries the request signature.
-- Note that this is @Authentication@, not the standard HTTP @Authorization@
-- header.
authHeaderName :: HeaderName
authHeaderName = "Authentication"
-- | Whether a header is the signature header ('authHeaderName'). The name
-- comparison is case-insensitive because 'HeaderName' is.
isAuthHeader :: Header -> Bool
isAuthHeader (name, _) = name == authHeaderName
-- | Raw MD5 digest of the input, as binary bytes (not hex-encoded).
hashMD5 :: ByteString -> ByteString
hashMD5 bytes = BA.convert (hash @ByteString @MD5 bytes)
{- | Remove the first element that satisfies the predicate.

Returns that element, if any, together with the list without it. The other
elements keep their relative order. When nothing matches, the result is
'Nothing' paired with the input list unchanged.
-}
extractOn :: (a -> Bool) -> [a] -> (Maybe a, [a])
extractOn predicate xs = maybe (Nothing, xs) takeMatch (uncons rest)
  where
    (prefix, rest) = break predicate xs
    takeMatch (found, suffix) = (Just found, prefix ++ suffix)