{-# LANGUAGE OverloadedStrings #-}
module SDJWT.Internal.KeyBinding
( createKeyBindingJWT
, computeSDHash
, verifyKeyBindingJWT
, addKeyBindingToPresentation
) where
import SDJWT.Internal.Types (HashAlgorithm(..), Digest(..), SDJWTPresentation(..), SDJWTError(..))
import SDJWT.Internal.Utils (hashToBytes, textToByteString, base64urlEncode, constantTimeEq, base64urlDecode)
import SDJWT.Internal.Serialization (serializePresentation)
import SDJWT.Internal.JWT (signJWTWithTyp, verifyJWT, JWKLike)
import qualified Data.Aeson as Aeson
import qualified Data.Aeson.Key as Key
import qualified Data.Aeson.KeyMap as KeyMap
import qualified Data.Text as T
import Data.Int (Int64)
-- | Create a signed Key Binding JWT (KB-JWT) for a presentation.
--
-- The payload always contains:
--
-- * @aud@     – the given audience,
-- * @nonce@   – the given nonce,
-- * @iat@     – the given issued-at value as a JSON number,
-- * @sd_hash@ – 'computeSDHash' of @presentation@ under @hashAlg@,
--
-- plus every entry of @optionalClaims@. The payload is signed with
-- @holderPrivateKey@ via 'signJWTWithTyp' using the header type @"kb+jwt"@.
--
-- The four required claims take precedence: an @optionalClaims@ entry with
-- the same key is dropped. Otherwise a caller could accidentally overwrite
-- @sd_hash@ (making the KB-JWT unverifiable) or the audience/nonce binding.
createKeyBindingJWT
  :: JWKLike jwk => HashAlgorithm  -- ^ Hash algorithm used for @sd_hash@
  -> jwk                           -- ^ Holder private key used to sign
  -> T.Text                        -- ^ Audience (@aud@)
  -> T.Text                        -- ^ Nonce (@nonce@)
  -> Int64                         -- ^ Issued-at (@iat@); presumably seconds since the Unix epoch (JWT NumericDate) — confirm with callers
  -> SDJWTPresentation             -- ^ Presentation being bound
  -> Aeson.Object                  -- ^ Additional payload claims
  -> IO (Either SDJWTError T.Text)
createKeyBindingJWT hashAlg holderPrivateKey audience nonce issuedAt presentation optionalClaims =
  signJWTWithTyp "kb+jwt" holderPrivateKey (Aeson.Object kbPayloadObj)
  where
    sdHash = computeSDHash hashAlg presentation

    requiredClaims =
      KeyMap.fromList
        [ (Key.fromText "aud", Aeson.String audience)
        , (Key.fromText "nonce", Aeson.String nonce)
        , (Key.fromText "iat", Aeson.Number (fromIntegral issuedAt))
        , (Key.fromText "sd_hash", Aeson.String (unDigest sdHash))
        ]

    -- 'KeyMap.union' is left-biased: putting the required claims on the
    -- left guarantees they are never overwritten by optional claims.
    kbPayloadObj = KeyMap.union requiredClaims optionalClaims
-- | Compute the @sd_hash@ value that binds a KB-JWT to a presentation.
--
-- The presentation's key-binding JWT is cleared (@keyBindingJWT = Nothing@)
-- and the result is serialized with 'serializePresentation'. That text is
-- converted to bytes with 'textToByteString', hashed with 'hashToBytes' under
-- @hashAlg@, and base64url-encoded into a 'Digest'.
--
-- NOTE(review): this assumes 'serializePresentation' of a KB-less
-- presentation yields the @\<JWT\>~\<disclosure\>~…~@ form RFC 9901 requires
-- as the hash input — confirm in the Serialization module.
computeSDHash
  :: HashAlgorithm
  -> SDJWTPresentation
  -> Digest
computeSDHash hashAlg presentation = Digest (base64urlEncode digestBytes)
  where
    -- Hash input never includes a KB-JWT, even if one is already attached.
    unboundPresentation = presentation { keyBindingJWT = Nothing }
    serialized = serializePresentation unboundPresentation
    digestBytes = hashToBytes hashAlg (textToByteString serialized)
-- | Verify a Key Binding JWT (KB-JWT) against the holder's public key and
-- the presentation it is attached to.
--
-- Checks, in order (the first failure is returned):
--
-- 1. The token has exactly three dot-separated parts
--    (header.payload.signature).
-- 2. The header base64url-decodes, parses as a JSON object, and has
--    @typ@ equal to @"kb+jwt"@ (RFC 9901 Section 4.3). This happens before
--    signature verification.
-- 3. 'verifyJWT' succeeds with @holderPublicKey@.
-- 4. The payload's @sd_hash@ is a string equal (via 'constantTimeEq') to
--    'computeSDHash' of @presentation@.
-- 5. The payload has @nonce@ and @aud@ as JSON strings and @iat@ as a JSON
--    number.
--
-- NOTE(review): only the presence and JSON type of @nonce@, @aud@ and @iat@
-- are checked; their values are not compared against expected values here,
-- so callers must check audience, nonce and freshness themselves.
verifyKeyBindingJWT
  :: JWKLike jwk => HashAlgorithm  -- ^ Hash algorithm used for @sd_hash@
  -> jwk                           -- ^ Holder public key
  -> T.Text                        -- ^ Compact-serialized KB-JWT
  -> SDJWTPresentation             -- ^ Presentation the KB-JWT is bound to
  -> IO (Either SDJWTError ())
verifyKeyBindingJWT hashAlg holderPublicKey kbJWT presentation =
  case checkHeader of
    Left err -> return (Left err)
    Right () -> do
      verifiedPayloadResult <- verifyJWT holderPublicKey kbJWT Nothing
      return (verifiedPayloadResult >>= checkPayload)
  where
    -- Structural and @typ@ checks on the not-yet-verified header.
    checkHeader :: Either SDJWTError ()
    checkHeader =
      case T.splitOn "." kbJWT of
        -- Exactly three parts; a 2-part or 4+-part token is malformed.
        [headerPart, _payloadPart, _signaturePart] -> do
          headerBytes <-
            either
              (\err -> Left (InvalidKeyBinding ("Failed to decode KB-JWT header: " <> err)))
              Right
              (base64urlDecode headerPart)
          headerJson <-
            either
              (\err -> Left (InvalidKeyBinding ("Failed to parse KB-JWT header: " <> T.pack err)))
              Right
              (Aeson.eitherDecodeStrict headerBytes)
          case headerJson of
            Aeson.Object hObj -> checkTyp (KeyMap.lookup "typ" hObj)
            _ -> Left (InvalidKeyBinding "Invalid KB-JWT header format: expected object")
        _ -> Left (InvalidKeyBinding "Invalid KB-JWT format: expected header.payload.signature")

    -- RFC 9901 Section 4.3 requires typ = "kb+jwt". A missing or non-string
    -- typ is reported as missing.
    checkTyp :: Maybe Aeson.Value -> Either SDJWTError ()
    checkTyp (Just (Aeson.String "kb+jwt")) = Right ()
    checkTyp (Just (Aeson.String typValue)) =
      Left $ InvalidKeyBinding $
        "Invalid KB-JWT typ: expected 'kb+jwt', got '" <> typValue <> "' (RFC 9901 Section 4.3)"
    checkTyp _ =
      Left (InvalidKeyBinding "Missing 'typ' header in KB-JWT (RFC 9901 Section 4.3 requires typ: 'kb+jwt')")

    -- Checks on the signature-verified payload: sd_hash first, then the
    -- types of the remaining required claims.
    checkPayload :: Aeson.Value -> Either SDJWTError ()
    checkPayload kbPayload = do
      sdHashClaim <- extractClaim "sd_hash" kbPayload
      hashText <- case sdHashClaim of
        Aeson.String t -> Right t
        _ -> Left (InvalidKeyBinding "Invalid sd_hash claim format")
      let computedHash = unDigest (computeSDHash hashAlg presentation)
      -- Constant-time comparison avoids leaking how many leading bytes match.
      if constantTimeEq (textToByteString hashText) (textToByteString computedHash)
        then requireClaimTypes kbPayload
        else Left (InvalidKeyBinding "sd_hash mismatch")

    requireClaimTypes :: Aeson.Value -> Either SDJWTError ()
    requireClaimTypes kbPayload =
      case ( extractClaim "nonce" kbPayload
           , extractClaim "aud" kbPayload
           , extractClaim "iat" kbPayload
           ) of
        (Right (Aeson.String _), Right (Aeson.String _), Right (Aeson.Number _)) -> Right ()
        _ -> Left (InvalidKeyBinding "Missing required claims (nonce, aud, iat)")
-- | Create a KB-JWT with 'createKeyBindingJWT' and attach it to the
-- presentation.
--
-- On success the returned presentation is @presentation@ with
-- @keyBindingJWT@ set to the new token. Any previously attached KB-JWT is
-- replaced. Errors from 'createKeyBindingJWT' are returned unchanged.
addKeyBindingToPresentation
  :: JWKLike jwk => HashAlgorithm  -- ^ Hash algorithm used for @sd_hash@
  -> jwk                           -- ^ Holder private key used to sign
  -> T.Text                        -- ^ Audience (@aud@)
  -> T.Text                        -- ^ Nonce (@nonce@)
  -> Int64                         -- ^ Issued-at (@iat@)
  -> SDJWTPresentation             -- ^ Presentation to bind
  -> Aeson.Object                  -- ^ Additional payload claims
  -> IO (Either SDJWTError SDJWTPresentation)
addKeyBindingToPresentation hashAlg holderKey audience nonce issuedAt presentation optionalClaims = do
  kbResult <- createKeyBindingJWT hashAlg holderKey audience nonce issuedAt presentation optionalClaims
  return $ case kbResult of
    Left err -> Left err
    Right kbToken -> Right (presentation { keyBindingJWT = Just kbToken })
-- | Look up a named claim in a KB-JWT payload.
--
-- Fails with 'InvalidKeyBinding' if the payload is not a JSON object, or if
-- the key is absent. A claim present with JSON @null@ is returned as
-- @Right Null@; callers decide whether that is acceptable.
extractClaim :: T.Text -> Aeson.Value -> Either SDJWTError Aeson.Value
extractClaim claimName payload =
  case payload of
    Aeson.Object obj ->
      maybe
        (Left (InvalidKeyBinding ("Missing claim: " <> claimName)))
        Right
        (KeyMap.lookup (Key.fromText claimName) obj)
    _ -> Left (InvalidKeyBinding "KB-JWT payload is not an object")