{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE FlexibleInstances #-}
{-# OPTIONS_GHC -Wno-deprecations #-}
-- | JWT signing and verification using jose library.
--
-- This module provides functions for signing and verifying JWTs using the
-- jose library. It supports both Text-based JWK strings and jose JWK objects.
module SDJWT.Internal.JWT
  ( signJWT
  , signJWTWithOptionalTyp
  , signJWTWithHeaders
  , signJWTWithTyp
  , verifyJWT
  , parseJWKFromText
  , JWKLike(..)
  ) where

import SDJWT.Internal.Types (SDJWTError(..))
import SDJWT.Internal.Utils (base64urlEncode, base64urlDecode)
import qualified Crypto.JOSE as Jose
import qualified Crypto.JOSE.JWS as JWS
import qualified Crypto.JOSE.JWK as JWK
import qualified Crypto.JOSE.Header as Header
import qualified Crypto.JOSE.JWA.JWS as JWA
import qualified Crypto.JOSE.Compact as Compact
import qualified Crypto.JOSE.Error as JoseError
import qualified Data.Aeson as Aeson
import qualified Data.Aeson.KeyMap as KeyMap
import qualified Data.Aeson.Key as Key
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.ByteString.Lazy as LBS
import qualified Data.ByteString as BS
import Control.Lens ((&), (?~), (^.), (^..))
import Data.Functor.Identity (Identity(..))
import Data.Time.Clock.POSIX (getPOSIXTime)
import Data.Scientific (toBoundedInteger)
import Data.Int (Int64)
import Data.Maybe (isJust)

-- | Anything that can be turned into a jose 'JWK.JWK'.
--
-- The signing and verification functions in this module are polymorphic over
-- this class. Callers can therefore hand over either a JWK encoded as a JSON
-- 'T.Text' string (no jose import needed) or a jose 'JWK.JWK' value they
-- already hold.
class JWKLike key where
  -- | Produce the jose key, or an 'SDJWTError' when the conversion fails.
  toJWK :: key -> Either SDJWTError JWK.JWK

-- | Text instance: the text is treated as a JWK JSON document and handed to
-- 'parseJWKFromText'. Any parse error from that function is returned as-is.
instance JWKLike T.Text where
  toJWK = parseJWKFromText

-- | JWK instance: the value is already a jose key, so the conversion is the
-- identity and never fails.
instance JWKLike JWK.JWK where
  toJWK = Right

-- | Detect the key type from a jose JWK object and return the appropriate
-- JWS algorithm name.
--
-- The JWK is serialised with 'Aeson.toJSON' and its @kty@ member (plus @crv@
-- or @alg@ where relevant) is inspected:
--
-- * @RSA@ — @"RS256"@ only when the JWK's optional @alg@ member is exactly
--   @"RS256"@; otherwise @"PS256"@.
-- * @EC@ — @"ES256"@ when @crv@ is @"P-256"@; any other curve, or a missing
--   curve, is an error.
-- * @OKP@ — @"EdDSA"@ when @crv@ is @"Ed25519"@; any other curve, or a missing
--   curve, is an error.
-- * any other @kty@, a missing @kty@, or a non-object JSON form — error.
--
-- Every failure is reported as 'InvalidSignature'.
detectKeyAlgorithmFromJWK :: JWK.JWK -> Either SDJWTError T.Text
detectKeyAlgorithmFromJWK jwk =
  case Aeson.toJSON jwk of
    Aeson.Object obj -> do
      kty <- case lookupField "kty" obj of
        Just (Aeson.String ktyText) -> Right ktyText
        _ -> Left $ InvalidSignature "Missing 'kty' field in JWK"
      case kty of
        "RSA" ->
          -- RFC 7517 allows an optional "alg" member on a JWK.
          -- RS256 (RSASSA-PKCS1-v1_5) is deprecated per
          -- draft-ietf-jose-deprecate-none-rsa15 (padding oracle attacks), so
          -- it is used only when the key explicitly asks for it. PS256
          -- (RSA-PSS) is the default.
          case lookupField "alg" obj of
            Just (Aeson.String "RS256") -> Right "RS256"
            _ -> Right "PS256"
        "EC" ->
          -- Only the P-256 curve is supported for EC keys.
          case lookupField "crv" obj of
            Just (Aeson.String "P-256") -> Right "ES256"
            Just (Aeson.String crvText) ->
              Left $ InvalidSignature $
                "Unsupported EC curve: " <> crvText <> " (only P-256 is supported)"
            _ -> Left $ InvalidSignature "Missing 'crv' field in EC JWK"
        "OKP" ->
          -- Only Ed25519 is supported for OKP keys (Ed448 is rejected).
          case lookupField "crv" obj of
            Just (Aeson.String "Ed25519") -> Right "EdDSA"
            Just (Aeson.String crvText) ->
              Left $ InvalidSignature $
                "Unsupported OKP curve: " <> crvText <> " (only Ed25519 is supported)"
            _ -> Left $ InvalidSignature "Missing 'crv' field in OKP JWK"
        _ ->
          Left $ InvalidSignature $
            "Unsupported key type: " <> kty <> " (supported: RSA, EC P-256, Ed25519)"
    _ -> Left $ InvalidSignature "Invalid JWK format: expected object"
  where
    -- Look up a top-level member of the serialised JWK by name.
    lookupField :: T.Text -> Aeson.Object -> Maybe Aeson.Value
    lookupField name = KeyMap.lookup (Key.fromText name)

-- | Convert an algorithm name to a jose 'JWA.Alg'.
--
-- Accepted names:
--
-- * @"PS256"@ — RSA-PSS; the recommended choice for RSA keys.
-- * @"RS256"@ — RSA-PKCS#1 v1.5; deprecated per
--   draft-ietf-jose-deprecate-none-rsa15 because of padding oracle attacks.
-- * @"EdDSA"@ and @"ES256"@.
--
-- Any other name yields 'InvalidSignature'.
toJwsAlg :: T.Text -> Either SDJWTError JWA.Alg
toJwsAlg "RS256" = Right JWA.RS256  -- Deprecated: Use PS256 instead (draft-ietf-jose-deprecate-none-rsa15)
toJwsAlg "PS256" = Right JWA.PS256
toJwsAlg "EdDSA" = Right JWA.EdDSA
toJwsAlg "ES256" = Right JWA.ES256
toJwsAlg alg =
  Left $ InvalidSignature $
    "Unsupported algorithm: " <> alg <> " (supported: PS256 default, RS256 deprecated, EdDSA, ES256)"

-- | Sign a JWT payload using a private key.
--
-- Returns the signed JWT as a compact string, or an error. No @typ@ or @kid@
-- header is added. The algorithm is detected from the key type:
--
-- - PS256 for RSA keys (default, RS256 also supported via JWK "alg" field)
-- - EdDSA for Ed25519 keys
-- - ES256 for EC P-256 keys
signJWT
  :: JWKLike jwk => jwk  -- ^ Private key JWK (Text or jose JWK object)
  -> Aeson.Value  -- ^ JWT payload
  -> IO (Either SDJWTError T.Text)
signJWT = signJWTWithOptionalTyp Nothing

-- | Sign a JWT payload with an optional @typ@ header parameter.
--
-- This lets issuer-signed JWTs carry a @typ@ header. RFC 9901 Section 9.11
-- recommends explicit typing, e.g. "sd-jwt" or "example+sd-jwt". Use 'signJWT'
-- for the default behaviour (no @typ@ header). No @kid@ header is added.
--
-- Returns the signed JWT as a compact string, or an error.
signJWTWithOptionalTyp
  :: JWKLike jwk => Maybe T.Text  -- ^ Optional typ header value (RFC 9901 Section 9.11 recommends explicit typing)
  -> jwk  -- ^ Private key JWK (Text or jose JWK object)
  -> Aeson.Value  -- ^ JWT payload
  -> IO (Either SDJWTError T.Text)
signJWTWithOptionalTyp mbTyp = signJWTWithHeaders mbTyp Nothing

-- | Sign a JWT payload with optional @typ@ and @kid@ header parameters.
--
-- Both headers are supported natively through jose's API and are placed in
-- the protected header.
--
-- The signing algorithm is picked from the key by 'detectKeyAlgorithmFromJWK':
-- PS256 for RSA (RS256 only when the JWK's @alg@ asks for it), ES256 for EC
-- P-256, and EdDSA for Ed25519.
--
-- Returns the signed JWT as a compact string
-- (@header.payload.signature@, each part base64url-encoded), or an error.
-- Errors from converting the key are returned unchanged; every other failure
-- is reported as 'InvalidSignature'.
signJWTWithHeaders
  :: JWKLike jwk => Maybe T.Text  -- ^ Optional typ header value (RFC 9901 Section 9.11 recommends explicit typing, e.g., "sd-jwt")
  -> Maybe T.Text  -- ^ Optional kid header value (Key ID for key management)
  -> jwk  -- ^ Private key JWK (Text or jose JWK object)
  -> Aeson.Value  -- ^ JWT payload
  -> IO (Either SDJWTError T.Text)
signJWTWithHeaders mbTyp mbKid privateKeyJWK payload =
  case prepared of
    Left err -> return (Left err)
    Right (jwk, jwsAlg) -> do
      -- Protected header: alg, plus typ/kid when supplied.
      let protectedParam = Header.HeaderParam Header.Protected
          header =
            JWS.newJWSHeader (Header.Protected, jwsAlg)
              & maybe id (\typValue -> Header.typ ?~ protectedParam typValue) mbTyp
              & maybe id (\kidValue -> Header.kid ?~ protectedParam kidValue) mbKid
          payloadBS = LBS.toStrict (Aeson.encode payload)

      -- Sign using the Identity container so the JWS carries a single signature.
      -- Note: Header.Protection is deprecated in newer jose versions but required for jose-0.10 compatibility
      signed <- Jose.runJOSE (JWS.signJWS payloadBS (Identity (header, jwk)))
        :: IO (Either JoseError.Error (JWS.JWS Identity Header.Protection JWS.JWSHeader))
      case signed of
        Left err ->
          return $ Left $ InvalidSignature $ "JWT signing failed: " <> T.pack (show err)
        Right jws -> renderCompact jwk jws
  where
    -- Pure preparation: convert the key and choose the JWS algorithm for it.
    -- Short-circuits on the first error.
    prepared = do
      jwk <- toJWK privateKeyJWK
      algText <- detectKeyAlgorithmFromJWK jwk
      jwsAlg <- toJwsAlg algText
      return (jwk, jwsAlg)

    -- Serialise a freshly signed single-signature JWS as
    -- base64url(header).base64url(payload).base64url(signature).
    renderCompact
      :: JWK.JWK
      -> JWS.JWS Identity Header.Protection JWS.JWSHeader
      -> IO (Either SDJWTError T.Text)
    renderCompact jwk jws =
      case jws ^.. JWS.signatures of
        [] -> return $ Left $ InvalidSignature "No signatures in JWS"
        (sigHead:_) -> do
          -- verifyJWSWithPayload with 'return' checks the new signature against
          -- the key and yields the raw (not yet base64url-encoded) payload bytes.
          payloadResult <-
            Jose.runJOSE (JWS.verifyJWSWithPayload return JWS.defaultValidationSettings jwk jws)
              :: IO (Either JoseError.Error BS.ByteString)
          return $ case payloadResult of
            Left err ->
              Left $ InvalidSignature $ "Failed to extract payload: " <> T.pack (show err)
            Right rawPayload ->
              Right $ T.intercalate "."
                [ TE.decodeUtf8 (JWS.rawProtectedHeader sigHead)  -- already base64url encoded
                , base64urlEncode rawPayload
                , base64urlEncode (sigHead ^. JWS.signature)      -- raw signature bytes
                ]

-- | Sign a JWT payload with a mandatory @typ@ header parameter.
--
-- The given value is always written to the protected header, which is what
-- KB-JWT needs (@typ: "kb+jwt"@, RFC 9901 Section 4.3).
--
-- Supports all algorithms: EC P-256 (ES256), RSA (PS256 default, RS256 also
-- supported), and Ed25519 (EdDSA).
--
-- Returns the signed JWT as a compact string, or an error.
signJWTWithTyp
  :: JWKLike jwk => T.Text  -- ^ typ header value (e.g., "kb+jwt" for KB-JWT)
  -> jwk  -- ^ Private key JWK (Text or jose JWK object)
  -> Aeson.Value  -- ^ JWT payload
  -> IO (Either SDJWTError T.Text)
signJWTWithTyp typValue = signJWTWithOptionalTyp (Just typValue)

-- | Verify a JWT signature using a public key.
--
-- Returns the decoded payload if verification succeeds, or an error.
verifyJWT
  :: JWKLike jwk => jwk  -- ^ Public key JWK (Text or jose JWK object)
  -> T.Text  -- ^ JWT to verify as a compact string
  -> Maybe T.Text  -- ^ Required typ header value (Nothing = allow any/none, Just "sd-jwt" = require exactly "sd-jwt")
  -> IO (Either SDJWTError Aeson.Value)
verifyJWT :: forall jwk.
JWKLike jwk =>
jwk -> Text -> Maybe Text -> IO (Either SDJWTError Value)
verifyJWT jwk
publicKeyJWK Text
jwtText Maybe Text
requiredTyp = do
  -- Convert to jose JWK
  case jwk -> Either SDJWTError JWK
forall a. JWKLike a => a -> Either SDJWTError JWK
toJWK jwk
publicKeyJWK of
    Left SDJWTError
err -> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either SDJWTError Value -> IO (Either SDJWTError Value))
-> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a b. (a -> b) -> a -> b
$ SDJWTError -> Either SDJWTError Value
forall a b. a -> Either a b
Left SDJWTError
err
    Right JWK
jwk -> do
      -- Decode compact JWT
      case ByteString -> Either Error (CompactJWS JWSHeader)
forall a e (m :: * -> *).
(FromCompact a, AsError e, MonadError e m) =>
ByteString -> m a
Compact.decodeCompact (ByteString -> ByteString
LBS.fromStrict (ByteString -> ByteString) -> ByteString -> ByteString
forall a b. (a -> b) -> a -> b
$ Text -> ByteString
TE.encodeUtf8 Text
jwtText) :: Either JoseError.Error (JWS.CompactJWS JWS.JWSHeader) of
        Left Error
err -> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either SDJWTError Value -> IO (Either SDJWTError Value))
-> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a b. (a -> b) -> a -> b
$ SDJWTError -> Either SDJWTError Value
forall a b. a -> Either a b
Left (SDJWTError -> Either SDJWTError Value)
-> SDJWTError -> Either SDJWTError Value
forall a b. (a -> b) -> a -> b
$ Text -> SDJWTError
InvalidSignature (Text -> SDJWTError) -> Text -> SDJWTError
forall a b. (a -> b) -> a -> b
$ Text
"Failed to decode JWT: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> String -> Text
T.pack (Error -> String
forall a. Show a => a -> String
show Error
err)
        Right CompactJWS JWSHeader
jws -> do
          -- Extract header from signature
          let sigs :: [Signature RequiredProtection JWSHeader]
sigs = CompactJWS JWSHeader
jws CompactJWS JWSHeader
-> Getting
     (Endo [Signature RequiredProtection JWSHeader])
     (CompactJWS JWSHeader)
     (Signature RequiredProtection JWSHeader)
-> [Signature RequiredProtection JWSHeader]
forall s a. s -> Getting (Endo [a]) s a -> [a]
^.. Getting
  (Endo [Signature RequiredProtection JWSHeader])
  (CompactJWS JWSHeader)
  (Signature RequiredProtection JWSHeader)
forall {k} (t :: * -> *) (p :: k) (a :: k -> *).
Foldable t =>
Fold (JWS t p a) (Signature p a)
Fold
  (CompactJWS JWSHeader) (Signature RequiredProtection JWSHeader)
JWS.signatures
          case [Signature RequiredProtection JWSHeader]
sigs of
            [] -> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either SDJWTError Value -> IO (Either SDJWTError Value))
-> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a b. (a -> b) -> a -> b
$ SDJWTError -> Either SDJWTError Value
forall a b. a -> Either a b
Left (SDJWTError -> Either SDJWTError Value)
-> SDJWTError -> Either SDJWTError Value
forall a b. (a -> b) -> a -> b
$ Text -> SDJWTError
InvalidSignature Text
"No signatures found in JWT"
            (Signature RequiredProtection JWSHeader
sig:[Signature RequiredProtection JWSHeader]
_) -> do
              let hdr :: JWSHeader RequiredProtection
hdr = Signature RequiredProtection JWSHeader
sig Signature RequiredProtection JWSHeader
-> Getting
     (JWSHeader RequiredProtection)
     (Signature RequiredProtection JWSHeader)
     (JWSHeader RequiredProtection)
-> JWSHeader RequiredProtection
forall s a. s -> Getting a s a -> a
^. Getting
  (JWSHeader RequiredProtection)
  (Signature RequiredProtection JWSHeader)
  (JWSHeader RequiredProtection)
forall {k} (p :: k) (a :: k -> *) (f :: * -> *).
(Contravariant f, Functor f) =>
(a p -> f (a p)) -> Signature p a -> f (Signature p a)
JWS.header
              
              -- SECURITY: RFC 8725bis - Extract and validate algorithm BEFORE verification
              -- We MUST NOT trust the alg value in the header - we must validate it matches the key
              let algParam :: Alg
algParam = JWSHeader RequiredProtection
hdr JWSHeader RequiredProtection
-> Getting Alg (JWSHeader RequiredProtection) Alg -> Alg
forall s a. s -> Getting a s a -> a
^. (HeaderParam RequiredProtection Alg
 -> Const Alg (HeaderParam RequiredProtection Alg))
-> JWSHeader RequiredProtection
-> Const Alg (JWSHeader RequiredProtection)
forall p. Lens' (JWSHeader p) (HeaderParam p Alg)
forall (a :: * -> *) p. HasAlg a => Lens' (a p) (HeaderParam p Alg)
Header.alg ((HeaderParam RequiredProtection Alg
  -> Const Alg (HeaderParam RequiredProtection Alg))
 -> JWSHeader RequiredProtection
 -> Const Alg (JWSHeader RequiredProtection))
-> ((Alg -> Const Alg Alg)
    -> HeaderParam RequiredProtection Alg
    -> Const Alg (HeaderParam RequiredProtection Alg))
-> Getting Alg (JWSHeader RequiredProtection) Alg
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Alg -> Const Alg Alg)
-> HeaderParam RequiredProtection Alg
-> Const Alg (HeaderParam RequiredProtection Alg)
forall p a (f :: * -> *).
Functor f =>
(a -> f a) -> HeaderParam p a -> f (HeaderParam p a)
Header.param
              let headerAlg :: Text
headerAlg = case Alg
algParam of
                    Alg
JWA.RS256 -> Text
"RS256"
                    Alg
JWA.PS256 -> Text
"PS256"
                    Alg
JWA.EdDSA -> Text
"EdDSA"
                    Alg
JWA.ES256 -> Text
"ES256"
                    Alg
_ -> Text
"UNSUPPORTED"
              
              -- Validate algorithm matches key type (RFC 8725bis requirement)
              Either SDJWTError Text
expectedAlgResult <- case JWK -> Either SDJWTError Text
detectKeyAlgorithmFromJWK JWK
jwk of
                Left SDJWTError
err -> Either SDJWTError Text -> IO (Either SDJWTError Text)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either SDJWTError Text -> IO (Either SDJWTError Text))
-> Either SDJWTError Text -> IO (Either SDJWTError Text)
forall a b. (a -> b) -> a -> b
$ SDJWTError -> Either SDJWTError Text
forall a b. a -> Either a b
Left SDJWTError
err
                Right Text
expectedAlg -> Either SDJWTError Text -> IO (Either SDJWTError Text)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either SDJWTError Text -> IO (Either SDJWTError Text))
-> Either SDJWTError Text -> IO (Either SDJWTError Text)
forall a b. (a -> b) -> a -> b
$ Text -> Either SDJWTError Text
forall a b. b -> Either a b
Right Text
expectedAlg
              
              case Either SDJWTError Text
expectedAlgResult of
                Left SDJWTError
err -> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either SDJWTError Value -> IO (Either SDJWTError Value))
-> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a b. (a -> b) -> a -> b
$ SDJWTError -> Either SDJWTError Value
forall a b. a -> Either a b
Left SDJWTError
err
                Right Text
expectedAlg -> do
                  -- Note: "none" algorithm is prevented by jose's type system (JWA.Alg doesn't include "none")
                  -- so headerAlg can never be "none" - jose will reject it during decodeCompact
                  -- Validate algorithm matches expected algorithm (RFC 8725bis - don't trust header)
                  if Text
headerAlg Text -> Text -> Bool
forall a. Eq a => a -> a -> Bool
/= Text
expectedAlg
                        then Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either SDJWTError Value -> IO (Either SDJWTError Value))
-> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a b. (a -> b) -> a -> b
$ SDJWTError -> Either SDJWTError Value
forall a b. a -> Either a b
Left (SDJWTError -> Either SDJWTError Value)
-> SDJWTError -> Either SDJWTError Value
forall a b. (a -> b) -> a -> b
$ Text -> SDJWTError
InvalidSignature (Text -> SDJWTError) -> Text -> SDJWTError
forall a b. (a -> b) -> a -> b
$ Text
"Algorithm mismatch: header claims '" Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
headerAlg Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
"', but key type requires '" Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
expectedAlg Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
"' (RFC 8725bis)"
                        else do
                          -- Validate algorithm is in whitelist
                          case Text -> Either SDJWTError Alg
toJwsAlg Text
expectedAlg of
                            Left SDJWTError
err -> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either SDJWTError Value -> IO (Either SDJWTError Value))
-> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a b. (a -> b) -> a -> b
$ SDJWTError -> Either SDJWTError Value
forall a b. a -> Either a b
Left SDJWTError
err
                            Right Alg
_ -> do
                                  -- Extract typ from header
                                  let mbTypValue :: Maybe Text
mbTypValue = case JWSHeader RequiredProtection
hdr JWSHeader RequiredProtection
-> Getting
     (Maybe (HeaderParam RequiredProtection Text))
     (JWSHeader RequiredProtection)
     (Maybe (HeaderParam RequiredProtection Text))
-> Maybe (HeaderParam RequiredProtection Text)
forall s a. s -> Getting a s a -> a
^. Getting
  (Maybe (HeaderParam RequiredProtection Text))
  (JWSHeader RequiredProtection)
  (Maybe (HeaderParam RequiredProtection Text))
forall p. Lens' (JWSHeader p) (Maybe (HeaderParam p Text))
forall (a :: * -> *) p.
HasTyp a =>
Lens' (a p) (Maybe (HeaderParam p Text))
Header.typ of
                                        Maybe (HeaderParam RequiredProtection Text)
Nothing -> Maybe Text
forall a. Maybe a
Nothing
                                        Just HeaderParam RequiredProtection Text
typParam -> Text -> Maybe Text
forall a. a -> Maybe a
Just (HeaderParam RequiredProtection Text
typParam HeaderParam RequiredProtection Text
-> Getting Text (HeaderParam RequiredProtection Text) Text -> Text
forall s a. s -> Getting a s a -> a
^. Getting Text (HeaderParam RequiredProtection Text) Text
forall p a (f :: * -> *).
Functor f =>
(a -> f a) -> HeaderParam p a -> f (HeaderParam p a)
Header.param)
                                  
                                  -- Validate typ header if required
                                  Either SDJWTError ()
typValidation <- case Maybe Text
requiredTyp of
                                    Maybe Text
Nothing -> Either SDJWTError () -> IO (Either SDJWTError ())
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either SDJWTError () -> IO (Either SDJWTError ()))
-> Either SDJWTError () -> IO (Either SDJWTError ())
forall a b. (a -> b) -> a -> b
$ () -> Either SDJWTError ()
forall a b. b -> Either a b
Right ()  -- Liberal mode: allow any typ or none
                                    Just Text
requiredTypValue -> do
                                      case Maybe Text
mbTypValue of
                                        Maybe Text
Nothing -> Either SDJWTError () -> IO (Either SDJWTError ())
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either SDJWTError () -> IO (Either SDJWTError ()))
-> Either SDJWTError () -> IO (Either SDJWTError ())
forall a b. (a -> b) -> a -> b
$ SDJWTError -> Either SDJWTError ()
forall a b. a -> Either a b
Left (SDJWTError -> Either SDJWTError ())
-> SDJWTError -> Either SDJWTError ()
forall a b. (a -> b) -> a -> b
$ Text -> SDJWTError
InvalidSignature (Text -> SDJWTError) -> Text -> SDJWTError
forall a b. (a -> b) -> a -> b
$ Text
"Missing typ header: required '" Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
requiredTypValue Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
"'"
                                        Just Text
typVal -> do
                                          if Text
typVal Text -> Text -> Bool
forall a. Eq a => a -> a -> Bool
== Text
requiredTypValue
                                            then Either SDJWTError () -> IO (Either SDJWTError ())
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either SDJWTError () -> IO (Either SDJWTError ()))
-> Either SDJWTError () -> IO (Either SDJWTError ())
forall a b. (a -> b) -> a -> b
$ () -> Either SDJWTError ()
forall a b. b -> Either a b
Right ()
                                            else Either SDJWTError () -> IO (Either SDJWTError ())
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either SDJWTError () -> IO (Either SDJWTError ()))
-> Either SDJWTError () -> IO (Either SDJWTError ())
forall a b. (a -> b) -> a -> b
$ SDJWTError -> Either SDJWTError ()
forall a b. a -> Either a b
Left (SDJWTError -> Either SDJWTError ())
-> SDJWTError -> Either SDJWTError ()
forall a b. (a -> b) -> a -> b
$ Text -> SDJWTError
InvalidSignature (Text -> SDJWTError) -> Text -> SDJWTError
forall a b. (a -> b) -> a -> b
$ Text
"Invalid typ header: expected '" Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
requiredTypValue Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
"', got '" Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
typVal Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
"'"
                                  
                                  case Either SDJWTError ()
typValidation of
                                    Left SDJWTError
err -> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either SDJWTError Value -> IO (Either SDJWTError Value))
-> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a b. (a -> b) -> a -> b
$ SDJWTError -> Either SDJWTError Value
forall a b. a -> Either a b
Left SDJWTError
err
                                    Right () -> do
                                      -- Verify JWT signature
                                      -- Note: jose's defaultValidationSettings does NOT validate exp/nbf claims,
                                      -- so we must validate them ourselves (see validateStandardClaims below)
                                      Either Error ByteString
result <- JOSE Error IO ByteString -> IO (Either Error ByteString)
forall e (m :: * -> *) a. JOSE e m a -> m (Either e a)
Jose.runJOSE (JOSE Error IO ByteString -> IO (Either Error ByteString))
-> JOSE Error IO ByteString -> IO (Either Error ByteString)
forall a b. (a -> b) -> a -> b
$ (ByteString -> JOSE Error IO ByteString)
-> ValidationSettings
-> JWK
-> CompactJWS JWSHeader
-> JOSE Error IO ByteString
forall a e (m :: * -> *) (h :: * -> *) p payload k s (t :: * -> *).
(HasAlgorithms a, HasValidationPolicy a, AsError e, MonadError e m,
 HasJWSHeader h, HasParams h,
 VerificationKeyStore m (h p) payload k, Cons s s Word8 Word8,
 AsEmpty s, Foldable t, ProtectionSupport p) =>
(s -> m payload) -> a -> k -> JWS t p h -> m payload
JWS.verifyJWSWithPayload ByteString -> JOSE Error IO ByteString
forall a. a -> JOSE Error IO a
forall (m :: * -> *) a. Monad m => a -> m a
return ValidationSettings
JWS.defaultValidationSettings JWK
jwk CompactJWS JWSHeader
jws :: IO (Either JoseError.Error BS.ByteString)
                                      
                                      case Either Error ByteString
result of
                                        Left Error
err -> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either SDJWTError Value -> IO (Either SDJWTError Value))
-> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a b. (a -> b) -> a -> b
$ SDJWTError -> Either SDJWTError Value
forall a b. a -> Either a b
Left (SDJWTError -> Either SDJWTError Value)
-> SDJWTError -> Either SDJWTError Value
forall a b. (a -> b) -> a -> b
$ Text -> SDJWTError
InvalidSignature (Text -> SDJWTError) -> Text -> SDJWTError
forall a b. (a -> b) -> a -> b
$ Text
"JWT verification failed: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> String -> Text
T.pack (Error -> String
forall a. Show a => a -> String
show Error
err)
                                        Right ByteString
payloadBS -> do
                                          -- Parse payload as JSON
                                          case ByteString -> Either String Value
forall a. FromJSON a => ByteString -> Either String a
Aeson.eitherDecodeStrict ByteString
payloadBS of
                                            Left String
jsonErr -> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either SDJWTError Value -> IO (Either SDJWTError Value))
-> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a b. (a -> b) -> a -> b
$ SDJWTError -> Either SDJWTError Value
forall a b. a -> Either a b
Left (SDJWTError -> Either SDJWTError Value)
-> SDJWTError -> Either SDJWTError Value
forall a b. (a -> b) -> a -> b
$ Text -> SDJWTError
JSONParseError (Text -> SDJWTError) -> Text -> SDJWTError
forall a b. (a -> b) -> a -> b
$ Text
"Failed to parse JWT payload: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> String -> Text
T.pack String
jsonErr
                                            Right Value
payload -> do
                                              -- Validate standard JWT claims (exp, nbf) if present
                                              -- jose library does not validate these, so we must do it ourselves
                                              Either SDJWTError ()
validationResult <- Value -> IO (Either SDJWTError ())
validateStandardClaims Value
payload
                                              case Either SDJWTError ()
validationResult of
                                                Left SDJWTError
err -> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either SDJWTError Value -> IO (Either SDJWTError Value))
-> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a b. (a -> b) -> a -> b
$ SDJWTError -> Either SDJWTError Value
forall a b. a -> Either a b
Left SDJWTError
err
                                                Right () -> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either SDJWTError Value -> IO (Either SDJWTError Value))
-> Either SDJWTError Value -> IO (Either SDJWTError Value)
forall a b. (a -> b) -> a -> b
$ Value -> Either SDJWTError Value
forall a b. b -> Either a b
Right Value
payload

-- | Validate standard JWT claims (exp, nbf) if present in the payload.
--
-- Per RFC 7519:
-- - exp (expiration time): Token is rejected if current time >= exp
-- - nbf (not before): Token is rejected if current time < nbf
--
-- Both claims are NumericDate values (RFC 7519, Section 2), which may be
-- non-integer (e.g. @1700000000.5@). Integer claims are compared exactly.
-- Fractional claims are compared at whole-second granularity, rounded in
-- the conservative direction: exp is rounded down and nbf is rounded up.
-- A fractional claim can therefore only make a token expire slightly
-- earlier or become valid slightly later, never the reverse.
--
-- Failures are reported as 'InvalidSignature' when:
--
-- * the token has expired, or is not yet valid;
-- * a claim is a JSON number outside the 'Int64' range;
-- * a claim is present but is not a JSON number.
--
-- exp is checked before nbf. Returns Right () if validation passes or if
-- the claims are not present. A payload that is not a JSON object is not
-- checked and yields Right ().
validateStandardClaims :: Aeson.Value -> IO (Either SDJWTError ())
validateStandardClaims (Aeson.Object obj) = do
  -- Truncate (do not round) the clock. For an integer claim value c,
  -- "floor now >= c" holds iff "now >= c", and "floor now < c" holds iff
  -- "now < c", so integer exp/nbf comparisons are exact. Rounding could
  -- accept a token up to half a second before its nbf.
  currentTime <- floor <$> getPOSIXTime :: IO Int64
  pure $ do
    checkClaim "exp" floor $ \expTime ->
      if currentTime >= expTime
        then Left (InvalidSignature "JWT has expired (exp claim)")
        else Right ()
    checkClaim "nbf" ceiling $ \nbfTime ->
      if currentTime < nbfTime
        then Left (InvalidSignature "JWT not yet valid (nbf claim)")
        else Right ()
  where
    -- Look up a NumericDate claim by name. An absent claim passes. A
    -- non-numeric or out-of-range value fails. Otherwise 'check' decides,
    -- given the claim converted to whole seconds with 'toSeconds'.
    checkClaim name toSeconds check =
      case KeyMap.lookup (Key.fromText name) obj of
        Nothing -> Right ()
        Just (Aeson.Number n) ->
          case numericDateSeconds toSeconds n of
            Just seconds -> check seconds
            Nothing ->
              Left . InvalidSignature $
                "Invalid " <> name <> " claim: value out of range for Int64"
        Just _ ->
          Left . InvalidSignature $
            "Invalid " <> name <> " claim format: must be a number"

    -- Convert a JSON number to whole seconds.
    --
    -- 'toBoundedInteger' handles integral values exactly and is safe for
    -- huge exponents. When it returns Nothing, the value is either out of
    -- range or non-integral. The range test uses Scientific's Ord instance,
    -- which does not expand 10^e. A non-integral value that is in range has
    -- a negative base-10 exponent, so the rounding function ('floor' or
    -- 'ceiling') stays cheap. Because the Int64 bounds are themselves
    -- integers, its result always fits in Int64.
    numericDateSeconds toSeconds n =
      case toBoundedInteger n of
        Just seconds -> Just seconds
        Nothing
          | n < fromIntegral (minBound :: Int64) -> Nothing
          | n > fromIntegral (maxBound :: Int64) -> Nothing
          | otherwise -> Just (toSeconds n)
validateStandardClaims _ = pure (Right ())  -- Not an object, skip validation

-- | Parse a JWK from JSON Text.
--
-- Parses a JSON Web Key (JWK) from its JSON representation (RFC 7517).
-- Parsing happens in two stages, and each stage has its own error message:
--
-- 1. The text is UTF-8 encoded and decoded as a generic JSON value.
--    On failure: @InvalidSignature "Failed to parse JWK JSON: ..."@.
-- 2. That JSON value is converted to a jose 'JWK.JWK' via its 'Aeson.FromJSON'
--    instance. On failure: @InvalidSignature "Failed to create JWK: ..."@.
--
-- Supports RSA, Ed25519, and EC P-256 keys. Examples:
--
-- - RSA public key: {"kty":"RSA","n":"...","e":"..."}
-- - Ed25519 public key: {"kty":"OKP","crv":"Ed25519","x":"..."}
-- - EC P-256 public key: {"kty":"EC","crv":"P-256","x":"...","y":"..."}
-- - RSA private key: {"kty":"RSA","n":"...","e":"...","d":"...","p":"...","q":"..."}
-- - Ed25519 private key: {"kty":"OKP","crv":"Ed25519","d":"...","x":"..."}
-- - EC P-256 private key: {"kty":"EC","crv":"P-256","d":"...","x":"...","y":"..."}
parseJWKFromText :: T.Text -> Either SDJWTError JWK.JWK
parseJWKFromText jwkText = do
  jsonValue <-
    either (failWith "Failed to parse JWK JSON: ") Right $
      Aeson.eitherDecodeStrict (TE.encodeUtf8 jwkText)
  case Aeson.fromJSON jsonValue of
    Aeson.Success key -> Right key
    Aeson.Error reason -> failWith "Failed to create JWK: " reason
  where
    -- Wrap aeson's String diagnostic in an InvalidSignature error, prefixed
    -- with a description of the stage that failed.
    failWith prefix = Left . InvalidSignature . (prefix <>) . T.pack