module Servant.Auth.Server.Internal.JWT where
import Control.Lens
import Control.Monad (MonadPlus (..), guard)
import Control.Monad.Reader
import qualified Crypto.JOSE as Jose
import qualified Crypto.JWT as Jose
import Data.ByteArray (constEq)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import Data.Maybe (fromMaybe)
import Data.Time (UTCTime)
import Network.Wai (requestHeaders)
import Servant.Auth.JWT (FromJWT (..), ToJWT (..))
import Servant.Auth.Server.Internal.ConfigTypes
import Servant.Auth.Server.Internal.Types
-- | An 'AuthCheck' that authenticates a request from a JWT bearer token.
--
-- The token is read from the @Authorization@ request header, which must use
-- the @Bearer@ scheme. The scheme name is matched case-insensitively
-- (RFC 7235 §2.1), so @bearer@ and @BEARER@ are accepted as well. One or
-- more spaces may separate the scheme from the token (RFC 6750 §2.1).
--
-- A missing header, or one using another scheme, yields 'mempty'. A token
-- for which 'verifyJWT' returns 'Nothing' yields 'mzero'.
jwtAuthCheck :: FromJWT usr => JWTSettings -> AuthCheck usr
jwtAuthCheck jwtSettings = do
  req <- ask
  token <- maybe mempty return $ do
    authHdr <- lookup "Authorization" $ requestHeaders req
    let bearer :: BS.ByteString
        bearer = "bearer "
        (mbearer, rest) = BS.splitAt (BS.length bearer) authHdr
    -- Normalise the received prefix to lower case so the scheme comparison
    -- is case-insensitive; 'constEq' keeps the comparison constant-time.
    guard (BS.map asciiToLower mbearer `constEq` bearer)
    -- Skip any additional spaces (byte 32) before the token itself.
    return (BS.dropWhile (== 32) rest)
  verifiedJWT <- liftIO $ verifyJWT jwtSettings token
  maybe mzero return verifiedJWT
  where
    -- Map an ASCII upper-case letter byte ('A'..'Z') to lower case;
    -- every other byte is returned unchanged.
    asciiToLower w
      | w >= 65 && w <= 90 = w + 32
      | otherwise = w
-- | Sign the claims that 'encodeJWT' produces for a value, and return the
-- compact serialisation of the resulting JWT.
--
-- The signing algorithm is 'jwtAlg' from the settings when it is set.
-- Otherwise it is the algorithm 'Jose.bestJWSAlg' selects for 'signingKey'.
--
-- When an expiry time is given, it is stored in the @exp@ claim. Otherwise
-- the claims set is used exactly as 'encodeJWT' produced it.
--
-- Failures while choosing the algorithm or signing are returned as 'Left'.
makeJWT
  :: ToJWT a
  => a
  -> JWTSettings
  -> Maybe UTCTime
  -> IO (Either Jose.Error BSL.ByteString)
makeJWT v cfg expiry = Jose.runJOSE $ do
  let signingJwk = signingKey cfg
  fallbackAlg <- Jose.bestJWSAlg signingJwk
  let header = Jose.newJWSHeader ((), fromMaybe fallbackAlg (jwtAlg cfg))
      claims = withExpiry (encodeJWT v)
  Jose.encodeCompact <$> Jose.signClaims signingJwk header claims
  where
    -- Set the @exp@ claim only when an expiry was requested.
    withExpiry claimsSet =
      maybe claimsSet
            (\e -> claimsSet & Jose.claimExp ?~ Jose.NumericDate e)
            expiry
-- | Verify a compact-serialised JWT and decode its claims.
--
-- The input is parsed with 'Jose.decodeCompact'. It is then checked with
-- 'Jose.verifyClaims', using the keys from 'validationKeys' and the
-- validation settings derived from the 'JWTSettings'. The verified claims
-- set is passed to 'decodeJWT'.
--
-- The result is 'Nothing' if verification fails or if 'decodeJWT' fails.
-- The error details are discarded in both cases.
verifyJWT :: FromJWT a => JWTSettings -> BS.ByteString -> IO (Maybe a)
verifyJWT jwtCfg input = do
  keys <- validationKeys jwtCfg
  outcome <- Jose.runJOSE $ do
    unverified <- Jose.decodeCompact (BSL.fromStrict input)
    Jose.verifyClaims
      (jwtSettingsToJwtValidationSettings jwtCfg)
      keys
      unverified
  -- The annotation fixes the JOSE error type used by 'Jose.runJOSE'.
  return $
    toMaybe (outcome :: Either Jose.JWTError Jose.ClaimsSet)
      >>= toMaybe . decodeJWT
  where
    -- Forget the error side of an 'Either'.
    toMaybe = either (const Nothing) Just