module Wai.CSRF
( Config (..)
, defaultConfig
, tokenFromRequestHeader
, tokenFromRequestCookie
, setCookie
, expireCookie
, middleware
, Token (..)
, randomToken
, tokenToBase64UU
, tokenFromBase64UU
, MaskedToken (..)
, maskedTokenToBase64UU
, maskedTokenFromBase64UU
, randomMaskToken
, unmaskToken
) where
import Crypto.Random qualified as C
import Data.ByteArray qualified as BA
import Data.ByteArray.Encoding qualified as BA
import Data.ByteArray.Sized qualified as BAS
import Data.ByteString qualified as B
import Data.CaseInsensitive qualified as CI
import Data.Time.Clock.POSIX qualified as Time
import Network.HTTP.Types qualified as H
import Network.Wai qualified as Wai
import Web.Cookie qualified as C
-- | A CSRF token: exactly 32 bytes, stored in a strict 'B.ByteString'.
newtype Token = Token (BAS.SizedByteArray 32 B.ByteString)

-- | Renders the token the same way its underlying 'B.ByteString' is shown.
-- Note that this prints the raw token bytes.
instance Show Token where
  showsPrec n (Token s) = showsPrec n $ BAS.unSizedByteArray s

-- | Equality is delegated to 'BA.constEq' instead of being derived.
instance Eq Token where
  Token a == Token b = BA.constEq a b
-- | Generate a fresh 'Token' from 32 bytes obtained via 'C.getRandomBytes'.
randomToken :: (C.MonadRandom m) => m Token
randomToken =
  -- 'BAS.unsafeSizedByteArray' is partial on a length mismatch. Exactly 32
  -- bytes are requested here, matching the size index of 'Token', so it is
  -- safe as long as 'C.getRandomBytes' returns the requested amount.
  fmap (Token . BAS.unsafeSizedByteArray) (C.getRandomBytes 32)
-- | Encode a 'Token' as unpadded, URL-safe Base64 ('BA.Base64URLUnpadded').
tokenToBase64UU :: Token -> B.ByteString
tokenToBase64UU (Token t) = BA.convertToBase BA.Base64URLUnpadded t
-- | Decode a 'Token' from unpadded, URL-safe Base64.
--
-- Returns 'Nothing' when the input cannot be decoded as
-- 'BA.Base64URLUnpadded', or when 'BAS.fromByteArrayAccess' rejects the
-- decoded bytes (a 'Token' is 32 bytes long).
tokenFromBase64UU :: B.ByteString -> Maybe Token
tokenFromBase64UU b =
  case BA.convertFromBase BA.Base64URLUnpadded b of
    -- The annotation pins the otherwise ambiguous output type of the decoder.
    Right (x :: BA.Bytes) -> Token <$> BAS.fromByteArrayAccess x
    _ -> Nothing
-- | A 'Token' combined with a 'Mask': 64 bytes in total.
newtype MaskedToken = MaskedToken (BAS.SizedByteArray 64 BA.Bytes)

-- | Renders the masked token the same way its underlying 'BA.Bytes' is shown.
instance Show MaskedToken where
  showsPrec n (MaskedToken s) = showsPrec n $ BAS.unSizedByteArray s

-- | Equality is delegated to 'BA.constEq' instead of being derived.
instance Eq MaskedToken where
  MaskedToken a == MaskedToken b = BA.constEq a b
-- | Combine a 'Mask' and a 'Token' into a 'MaskedToken'.
--
-- The result is the 32 mask bytes followed by the 32 bytes of
-- @mask XOR token@.
toMaskedToken :: Mask -> Token -> MaskedToken
toMaskedToken (Mask m) (Token s) =
  let x = BAS.xor m s
   in -- 'asTypeOf' fixes the polymorphic result type of 'BAS.xor' to that of
      -- the mask, so both halves being appended share a representation.
      MaskedToken $! BAS.append m (x `asTypeOf` m)
-- | Split a 'MaskedToken' back into its 'Mask' and the original 'Token'.
--
-- The first 32 bytes are taken as the mask; the remaining 32 bytes are XORed
-- with that mask to recover the token.
fromMaskedToken :: MaskedToken -> (Mask, Token)
fromMaskedToken (MaskedToken t) =
  let (m, x) = BAS.splitAt t
   in -- 'asTypeOf' pins the type of the second half, which 'BAS.splitAt'
      -- leaves polymorphic.
      (Mask m, Token $! BAS.xor m (x `asTypeOf` m))
-- | Encode a 'MaskedToken' as unpadded, URL-safe Base64
-- ('BA.Base64URLUnpadded').
maskedTokenToBase64UU :: MaskedToken -> B.ByteString
maskedTokenToBase64UU (MaskedToken t) =
  BA.convertToBase BA.Base64URLUnpadded t
-- | Decode a 'MaskedToken' from unpadded, URL-safe Base64.
--
-- Returns 'Nothing' when the input cannot be decoded as
-- 'BA.Base64URLUnpadded', or when 'BAS.fromByteArrayAccess' rejects the
-- decoded bytes (a 'MaskedToken' is 64 bytes long).
maskedTokenFromBase64UU :: B.ByteString -> Maybe MaskedToken
maskedTokenFromBase64UU b =
  case BA.convertFromBase BA.Base64URLUnpadded b of
    -- The annotation pins the otherwise ambiguous output type of the decoder.
    Right (x :: BA.Bytes) -> MaskedToken <$> BAS.fromByteArrayAccess x
    _ -> Nothing
-- | A 32-byte mask that is XORed with a 'Token' to build a 'MaskedToken'.
newtype Mask = Mask (BAS.SizedByteArray 32 BA.Bytes)

-- | Renders the mask the same way its underlying 'BA.Bytes' is shown.
instance Show Mask where
  showsPrec n (Mask s) = showsPrec n $ BAS.unSizedByteArray s

-- | Equality is delegated to 'BA.constEq' instead of being derived.
instance Eq Mask where
  Mask a == Mask b = BA.constEq a b
-- | Generate a fresh 'Mask' from 32 bytes obtained via 'C.getRandomBytes'.
randomMask :: (C.MonadRandom m) => m Mask
randomMask =
  -- 'BAS.unsafeSizedByteArray' is partial on a length mismatch. Exactly 32
  -- bytes are requested here, matching the size index of 'Mask', so it is
  -- safe as long as 'C.getRandomBytes' returns the requested amount.
  fmap (Mask . BAS.unsafeSizedByteArray) (C.getRandomBytes 32)
-- | Mask a 'Token' with a freshly generated random mask.
--
-- Every call draws a new mask, so masking the same 'Token' twice is expected
-- to give different 'MaskedToken's; 'unmaskToken' recovers the original.
randomMaskToken :: (C.MonadRandom m) => Token -> m MaskedToken
randomMaskToken t = (`toMaskedToken` t) <$> randomMask
-- | Recover the original 'Token' from a 'MaskedToken', discarding the mask.
unmaskToken :: MaskedToken -> Token
unmaskToken = snd . fromMaskedToken
-- | Settings shared by 'middleware' and the cookie/header helpers.
data Config = Config
  { cookieName :: B.ByteString
  -- ^ Name of the cookie that carries the Base64-encoded 'Token'.
  , headerName :: B.ByteString
  -- ^ Name of the request header that carries the Base64-encoded 'Token'.
  -- It is matched case-insensitively.
  , reject :: Wai.Request -> Maybe (Token, Bool) -> Maybe Wai.Response
  -- ^ Decides whether a request is refused. 'middleware' passes the request
  -- and, when both the cookie token and the header token could be read, the
  -- cookie token paired with whether the two tokens are equal ('Nothing'
  -- otherwise). Returning @'Just' response@ answers the request with that
  -- response; returning 'Nothing' lets the wrapped application run.
  }
-- | Default settings.
--
-- * 'cookieName' is @\"CSRF-TOKEN\"@.
--
-- * 'headerName' is @\"X-CSRF-TOKEN\"@.
--
-- * 'reject' lets @GET@, @HEAD@, @OPTIONS@ and @TRACE@ requests through
--   unconditionally. Any other request is let through only when the cookie
--   and header tokens were both read and are equal; otherwise it is answered
--   with a @403 Forbidden@ response whose body is @\"CSRF\"@.
defaultConfig :: Config
defaultConfig =
  Config
    { cookieName = "CSRF-TOKEN"
    , headerName = "X-CSRF-TOKEN"
    , reject = \req yteq ->
        if
          | Wai.requestMethod req `elem` uncheckedMethods -> Nothing
          | Just (_, True) <- yteq -> Nothing
          | otherwise -> Just $ Wai.responseLBS H.forbidden403 [] "CSRF"
    }
  where
    -- Methods for which the token comparison is skipped entirely.
    uncheckedMethods =
      [H.methodGet, H.methodHead, H.methodOptions, H.methodTrace]
-- | Read the 'Token' from the request header named by 'headerName'.
--
-- Returns 'Nothing' unless the header occurs exactly once and its value is
-- accepted by 'tokenFromBase64UU'.
tokenFromRequestHeader :: Config -> Wai.Request -> Maybe Token
-- Written as a lambda over the request so that the case-insensitive header
-- name depends only on the 'Config' and can be shared when this function is
-- partially applied.
tokenFromRequestHeader c = \r -> do
  -- A failed pattern match (zero or several values) yields 'Nothing' in the
  -- 'Maybe' monad.
  [t64] <- pure $ lookupMany n $ Wai.requestHeaders r
  tokenFromBase64UU t64
  where
    n = CI.mk c.headerName
-- | Read the 'Token' from the request cookie named by 'cookieName'.
--
-- Returns 'Nothing' unless the cookie occurs exactly once across the
-- request's @Cookie@ headers and its value is accepted by
-- 'tokenFromBase64UU'.
tokenFromRequestCookie :: Config -> Wai.Request -> Maybe Token
tokenFromRequestCookie c r = do
  -- A failed pattern match (zero or several values) yields 'Nothing' in the
  -- 'Maybe' monad.
  [t64] <- pure $ lookupMany c.cookieName $ requestCookies r
  tokenFromBase64UU t64
-- | Build the @Set-Cookie@ value that stores the given 'Token'.
--
-- It starts from 'expireCookie', so name, path and the other attributes are
-- the same, then sets the value to the Base64-encoded token and clears both
-- the expiry time and the max-age.
setCookie :: Config -> Token -> C.SetCookie
setCookie c tok =
  (expireCookie c)
    { C.setCookieValue = tokenToBase64UU tok
    , C.setCookieExpires = Nothing
    , C.setCookieMaxAge = Nothing
    }
-- | Build the @Set-Cookie@ value that removes the token cookie.
--
-- The cookie is named after 'cookieName', has an empty value, an expiry at
-- the POSIX epoch and a negative max-age. It is scoped to path @\"/\"@ with
-- no explicit domain, and is marked @Secure@ and @SameSite=Lax@.
expireCookie :: Config -> C.SetCookie
expireCookie c =
  C.defaultSetCookie
    { C.setCookieName = c.cookieName
    , C.setCookieValue = ""
    , C.setCookieDomain = Nothing
    , C.setCookieExpires = Just (Time.posixSecondsToUTCTime 0)
    , -- NOTE(review): presumably left readable by scripts so that client code
      -- can copy the token into the request header; confirm.
      C.setCookieHttpOnly = False
    , C.setCookieMaxAge = Just (negate 1)
    , C.setCookiePath = Just "/"
    , C.setCookieSameSite = Just C.sameSiteLax
    , C.setCookieSecure = True
    }
-- | Wrap an application with the CSRF check described by a 'Config'.
--
-- For every request the token is read from the cookie
-- ('tokenFromRequestCookie') and from the header ('tokenFromRequestHeader').
-- When both are available, 'reject' receives the cookie token paired with
-- whether the two tokens are equal; otherwise it receives 'Nothing'.
--
-- If 'reject' returns a response, that response is sent and the wrapped
-- application is not run. Otherwise the wrapped application runs and is
-- given the cookie token, if one could be read.
middleware
  :: Config
  -> (Maybe Token -> Wai.Application)
  -> Wai.Application
-- Written as a lambda so the two extractors below depend only on the
-- 'Config' and can be shared when 'middleware' is partially applied.
middleware c = \fapp req respond -> do
  let yct = fyct req
      yte = liftA2 (\ct ht -> (ct, ct == ht)) yct (fyht req)
  case c.reject req yte of
    Nothing -> fapp yct req respond
    Just res -> respond res
  where
    fyct = tokenFromRequestCookie c
    fyht = tokenFromRequestHeader c
-- | All cookie name/value pairs of a request.
--
-- Every @Cookie@ header is parsed with 'C.parseCookies' and the results are
-- concatenated in header order.
requestCookies :: Wai.Request -> [(B.ByteString, B.ByteString)]
requestCookies r =
  concatMap C.parseCookies (lookupMany "Cookie" (Wai.requestHeaders r))
-- | Like 'lookup', but returns the values of /all/ entries whose key equals
-- the given one, in their original order.
lookupMany :: (Eq k) => k -> [(k, v)] -> [v]
lookupMany k = findMany (== k)
-- | Values of all entries whose key satisfies the predicate, in their
-- original order.
findMany :: (Eq k) => (k -> Bool) -> [(k, v)] -> [v]
findMany f kvs = [v | (k, v) <- kvs, f k]