{-# LANGUAGE OverloadedStrings #-}
-- | Hash computation and verification for SD-JWT disclosures (low-level).
--
-- This module provides functions for computing digests of disclosures
-- and verifying that digests match disclosures. All three hash algorithms
-- required by RFC 9901 are supported: SHA-256, SHA-384, and SHA-512.
--
-- == Usage
--
-- This module contains low-level hash and digest utilities that are typically
-- used internally by other SD-JWT modules. Most users should use the higher-level
-- APIs in:
--
-- * 'SDJWT.Issuer' - For issuers (handles digest computation internally)
-- * 'SDJWT.Holder' - For holders (handles digest computation internally)
-- * 'SDJWT.Verifier' - For verifiers (handles digest verification internally)
--
-- These utilities may be useful for:
--
-- * Advanced use cases requiring custom digest computation
-- * Library developers building on top of SD-JWT
-- * Testing and debugging
--
module SDJWT.Internal.Digest
  ( computeDigest
  , computeDigestText
  , verifyDigest
  , parseHashAlgorithm
  , defaultHashAlgorithm
  , hashAlgorithmToText
  , extractDigestsFromValue
  , extractDigestStringsFromSDArray
  ) where

import SDJWT.Internal.Types (HashAlgorithm(..), Digest(..), EncodedDisclosure(..), SDJWTError(..))
import SDJWT.Internal.Utils (hashToBytes, base64urlEncode, constantTimeEq, textToByteString)
import qualified Data.Aeson as Aeson
import qualified Data.Aeson.Key as Key
import qualified Data.Aeson.KeyMap as KeyMap
import qualified Data.Vector as V
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import Data.Maybe (mapMaybe)
import Control.Monad (mapM)

-- | Default hash algorithm (SHA-256 per RFC 9901).
--
-- When the @_sd_alg@ claim is not present in an SD-JWT, SHA-256 is used
-- as the default hash algorithm.
defaultHashAlgorithm :: HashAlgorithm
defaultHashAlgorithm = SHA256

-- | Convert a hash algorithm to its text identifier.
--
-- Returns the hash algorithm name as specified in RFC 9901:
-- @"sha-256"@, @"sha-384"@, or @"sha-512"@.
--
-- The identifiers produced here are exactly the ones accepted by
-- 'parseHashAlgorithm'.
hashAlgorithmToText :: HashAlgorithm -> T.Text
hashAlgorithmToText SHA256 = "sha-256"
hashAlgorithmToText SHA384 = "sha-384"
hashAlgorithmToText SHA512 = "sha-512"

-- | Parse a hash algorithm from its text identifier.
--
-- Parses hash algorithm names from the @_sd_alg@ claim. Matching is exact
-- (case-sensitive, no trimming): only @"sha-256"@, @"sha-384"@ and
-- @"sha-512"@ are accepted.
--
-- Returns 'Nothing' if the algorithm is not recognized.
parseHashAlgorithm :: T.Text -> Maybe HashAlgorithm
parseHashAlgorithm "sha-256" = Just SHA256
parseHashAlgorithm "sha-384" = Just SHA384
parseHashAlgorithm "sha-512" = Just SHA512
parseHashAlgorithm _ = Nothing

-- | Compute the digest of a disclosure.
--
-- The digest is computed over the US-ASCII bytes of the base64url-encoded
-- disclosure string (per RFC 9901). The bytes of the hash output are then
-- base64url encoded to produce the final digest.
--
-- This follows the convention in JWS (RFC 7515) and JWE (RFC 7516).
--
-- Note: RFC 9901 requires US-ASCII encoding. Since base64url strings contain
-- only ASCII characters (A-Z, a-z, 0-9, -, _), UTF-8 encoding produces
-- identical bytes to US-ASCII for these strings.
computeDigest :: HashAlgorithm -> EncodedDisclosure -> Digest
computeDigest alg (EncodedDisclosure encoded) =
  let
    -- Convert the base64url-encoded disclosure to bytes.
    -- UTF-8 encoding is equivalent to US-ASCII for base64url strings (ASCII-only).
    disclosureBytes = TE.encodeUtf8 encoded
    -- Hash the encoded disclosure with the selected algorithm.
    hashBytes = hashToBytes alg disclosureBytes
    -- Base64url encode the raw hash bytes.
    digestText = base64urlEncode hashBytes
  in
    Digest digestText

-- | Compute the digest text (string) of a disclosure.
--
-- Convenience function that computes the digest and extracts the text.
-- Equivalent to @unDigest . computeDigest alg@.
computeDigestText :: HashAlgorithm -> EncodedDisclosure -> T.Text
computeDigestText alg = unDigest . computeDigest alg

-- | Verify that a digest matches a disclosure.
--
-- Computes the digest of the disclosure using the specified hash algorithm
-- and compares it to the expected digest via 'constantTimeEq'.
-- Returns 'True' if they match.
--
-- SECURITY: The comparison is delegated to 'constantTimeEq' rather than
-- '==' to prevent timing attacks. This is critical for cryptographic
-- verification operations.
verifyDigest :: HashAlgorithm -> Digest -> EncodedDisclosure -> Bool
verifyDigest alg expectedDigest disclosure =
  let
    computedDigest = computeDigest alg disclosure
    -- Convert both digests to ByteString for the constant-time comparison.
    expectedBytes = textToByteString (unDigest expectedDigest)
    computedBytes = textToByteString (unDigest computedDigest)
  in
    constantTimeEq expectedBytes computedBytes

-- | Recursively extract digests from a JSON value (@_sd@ arrays and array
-- ellipsis objects).
--
-- This function extracts all digests from a JSON value by:
--
-- 1. Looking for @_sd@ arrays in objects and extracting string digests
-- 2. Looking for @{"...": "\<digest\>"}@ objects in arrays
-- 3. Recursively processing nested structures
--
-- Used for extracting digests from SD-JWT payloads and disclosure values.
--
-- Per RFC 9901 Section 4.2.4.1, @_sd@ arrays MUST contain only strings
-- (digests). Returns 'InvalidDigest' if a non-string value is found in an
-- @_sd@ array, or if an array element carries a string-valued @"..."@ key
-- together with any other key (RFC 9901 Section 4.2.4.2).
--
-- For an object, the digests of its own @_sd@ array come first, followed by
-- the digests found in its member values. Scalars yield no digests.
extractDigestsFromValue :: Aeson.Value -> Either SDJWTError [Digest]
extractDigestsFromValue (Aeson.Object obj) = do
  topLevelDigests <- case KeyMap.lookup "_sd" obj of
    Just (Aeson.Array arr) -> mapM sdEntryToDigest (V.toList arr)
    -- Absent, or present but not an array: nothing to collect at this level.
    _ -> Right []
  -- Recursively extract from nested objects
  nestedDigests <- mapM (extractDigestsFromValue . snd) (KeyMap.toList obj)
  return $ topLevelDigests ++ concat nestedDigests
  where
    -- Every entry of an _sd array must be a string digest.
    sdEntryToDigest (Aeson.String s) = Right (Digest s)
    sdEntryToDigest _ =
      Left $ InvalidDigest "_sd array must contain only string digests (RFC 9901 Section 4.2.4.1)"
extractDigestsFromValue (Aeson.Array arr) = do
  -- Check for array ellipsis objects {"...": "<digest>"}
  -- Per RFC 9901 Section 4.2.4.2: "There MUST NOT be any other keys in the object."
  results <- mapM digestsFromElement (V.toList arr)
  return $ concat results
  where
    digestsFromElement el@(Aeson.Object obj) =
      case KeyMap.lookup (Key.fromText "...") obj of
        Just (Aeson.String digest)
          -- Validate that the ellipsis object only contains the "..." key
          | KeyMap.size obj == 1 -> Right [Digest digest]
          | otherwise ->
              Left $ InvalidDigest "Ellipsis object must contain only the \"...\" key (RFC 9901 Section 4.2.4.2)"
        -- No "..." key, or its value is not a string: treat as an ordinary
        -- object and recurse into it.
        _ -> extractDigestsFromValue el
    -- Non-object elements: recursively check nested structures
    digestsFromElement el = extractDigestsFromValue el
extractDigestsFromValue _ = Right []

-- | Extract digest strings from an @_sd@ array in a JSON object.
--
-- This helper function extracts string digests from the @_sd@ array field
-- of a JSON object. Returns an empty list if @_sd@ is not present or not an
-- array. Non-string entries inside the array are silently skipped (unlike
-- 'extractDigestsFromValue', which rejects them), and nested values are not
-- searched.
--
-- This is a convenience function for cases where you only need the digest
-- strings, not the full 'Digest' type.
extractDigestStringsFromSDArray :: Aeson.Object -> [T.Text]
extractDigestStringsFromSDArray obj =
  case KeyMap.lookup "_sd" obj of
    Just (Aeson.Array arr) -> mapMaybe stringValue (V.toList arr)
    _ -> []
  where
    -- Keep string entries, drop everything else.
    stringValue (Aeson.String s) = Just s
    stringValue _ = Nothing