{-# LANGUAGE OverloadedStrings #-}
-- | Utility functions for SD-JWT operations (low-level).
--
-- This module provides base64url encoding/decoding, salt generation,
-- text/ByteString conversions, digest computation, constant-time comparison,
-- and JSON Pointer helpers used throughout the SD-JWT library.
--
-- == Usage
--
-- This module contains low-level utilities that are typically used internally
-- by other SD-JWT modules. Most users should use the higher-level APIs in:
--
-- * 'SDJWT.Issuer' - For issuers
-- * 'SDJWT.Holder' - For holders  
-- * 'SDJWT.Verifier' - For verifiers
--
-- These utilities may be useful for:
--
-- * Advanced use cases requiring custom implementations
-- * Library developers building on top of SD-JWT
-- * Testing and debugging
--
module SDJWT.Internal.Utils
  ( base64urlEncode
  , base64urlDecode
  , textToByteString
  , byteStringToText
  , hashToBytes
  , splitJSONPointer
  , unescapeJSONPointer
  , constantTimeEq
  , generateSalt  -- Internal use only, not part of public API
  , groupPathsByFirstSegment
  ) where

import qualified Data.ByteString.Base64.URL as Base64
import qualified Data.ByteString as BS
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Crypto.Random as RNG
import qualified Crypto.Hash as Hash
import qualified Data.ByteArray as BA
import qualified Data.Map.Strict as Map
import Control.Monad.IO.Class (MonadIO, liftIO)
import SDJWT.Internal.Types (HashAlgorithm(..))

-- | Encode raw bytes as unpadded base64url text (RFC 4648 Section 5).
--
-- The output uses the URL-safe alphabet and never carries trailing @=@
-- padding, which is the form SD-JWT uses for disclosures and JWT segments.
--
-- >>> base64urlEncode "Hello, World!"
-- "SGVsbG8sIFdvcmxkIQ"
base64urlEncode :: BS.ByteString -> T.Text
base64urlEncode bytes = TE.decodeUtf8 encoded
  where
    -- The base64url alphabet is pure ASCII, so decoding it as UTF-8 cannot fail.
    encoded = Base64.encodeUnpadded bytes

-- | Base64url decode a Text, accepting both padded and unpadded input.
--
-- The input is UTF-8 encoded and then decoded as base64url (RFC 4648
-- Section 5). Input ending in @=@ is decoded as padded base64url, and its
-- padding must then be well-formed. Any other input is decoded as unpadded
-- base64url, which is the form 'base64urlEncode' produces.
--
-- Returns 'Left' with the decoder's error message if decoding fails.
base64urlDecode :: T.Text -> Either T.Text BS.ByteString
base64urlDecode t =
  -- The decoder already reports a String; pack it directly. Going through
  -- 'show' would wrap the message in escaped quotes.
  either (Left . T.pack) Right (decoder (TE.encodeUtf8 t))
  where
    -- 'Base64.decodeUnpadded' rejects any '=' padding, so padded input must be
    -- routed to 'Base64.decodePadded' for both forms to be accepted.
    decoder :: BS.ByteString -> Either String BS.ByteString
    decoder
      | "=" `T.isSuffixOf` t = Base64.decodePadded
      | otherwise = Base64.decodeUnpadded

-- | Produce a fresh 16-byte (128-bit) random salt for a disclosure.
--
-- RFC 9901 recommends 128 bits of salt so that disclosure digests cannot be
-- guessed or brute-forced. The bytes are drawn with 'RNG.getRandomBytes'
-- running in 'IO', lifted into any 'MonadIO'.
generateSalt :: MonadIO m => m BS.ByteString
generateSalt = liftIO (RNG.getRandomBytes saltLengthBytes)
  where
    saltLengthBytes :: Int
    saltLengthBytes = 16

-- | Encode Text as a UTF-8 ByteString.
--
-- This is a total convenience wrapper around 'TE.encodeUtf8': every Text
-- value has a UTF-8 encoding.
textToByteString :: T.Text -> BS.ByteString
textToByteString txt = TE.encodeUtf8 txt

-- | Decode a UTF-8 ByteString into Text.
--
-- This is a thin wrapper around 'TE.decodeUtf8'. It throws an exception if
-- the input is not valid UTF-8, so only use it on bytes known to be UTF-8.
-- For safe decoding, use 'Data.Text.Encoding.decodeUtf8'' instead.
byteStringToText :: BS.ByteString -> T.Text
byteStringToText bytes = TE.decodeUtf8 bytes

-- | Compute the raw digest of a ByteString with the given algorithm.
--
-- Supports SHA-256, SHA-384 and SHA-512 via "Crypto.Hash". The digest is
-- returned as plain bytes, not hex or base64.
hashToBytes :: HashAlgorithm -> BS.ByteString -> BS.ByteString
hashToBytes alg input =
  case alg of
    SHA256 -> toBytes (Hash.hash input :: Hash.Digest Hash.SHA256)
    SHA384 -> toBytes (Hash.hash input :: Hash.Digest Hash.SHA384)
    SHA512 -> toBytes (Hash.hash input :: Hash.Digest Hash.SHA512)
  where
    -- The annotations above choose the cryptonite algorithm; the conversion
    -- to a ByteString is the same for all of them.
    toBytes :: Hash.Digest a -> BS.ByteString
    toBytes = BA.convert

-- | Split a JSON Pointer path on "\/" and decode RFC 6901 escapes per segment.
--
-- Escape sequences inside a segment are decoded as follows:
--
-- - "~1" becomes a literal forward slash "\/"
-- - "~0" becomes a literal tilde "~"
--
-- Examples:
--
-- - "a\/b" → ["a", "b"]
-- - "a~1b" → ["a\/b"] (escaped slash)
-- - "a~0b" → ["a~b"] (escaped tilde)
-- - "a~1\/b" → ["a\/", "b"] (the escaped slash stays in the segment, the raw "\/" separates)
-- - "~01" → ["~1"] ("~0" decodes to "~", and the following "1" is literal)
--
-- Note: this function is meant for JSON Pointer paths written without the
-- leading "\/". Leading slashes are stripped, trailing slashes don't create
-- empty segments, and consecutive slashes are collapsed. A "~" that is not
-- followed by "0" or "1" is kept as a literal character.
splitJSONPointer :: T.Text -> [T.Text]
splitJSONPointer path =
  -- An escaped slash is spelled "~1" and contains no '/', so every raw '/' is
  -- a separator. We can therefore split first and unescape each piece.
  -- Empty raw pieces are dropped before decoding, which is safe because
  -- decoding never turns a non-empty piece into an empty one.
  -- Splitting and then decoding whole segments avoids the quadratic cost of
  -- growing strict Text one character at a time.
  map unescapeSegment (filter (not . T.null) (T.splitOn "/" path))
  where
    -- RFC 6901 Section 4: decode "~1" before "~0" so that "~01" yields "~1".
    -- ('.' applies right-to-left, so the "~1" replacement runs first.)
    unescapeSegment :: T.Text -> T.Text
    unescapeSegment = T.replace "~0" "~" . T.replace "~1" "/"

-- | Unescape a single JSON Pointer reference token (RFC 6901).
--
-- Converts escape sequences back to literal characters:
--
-- - "~1" → "\/"
-- - "~0" → "~"
--
-- Order matters: RFC 6901 Section 4 requires "~1" to be decoded before "~0".
-- Otherwise "~01" would be decoded twice (to "\/") instead of to "~1".
unescapeJSONPointer :: T.Text -> T.Text
unescapeJSONPointer =
  -- Function composition applies right-to-left, so the "~1" replacement runs
  -- first and the "~0" replacement second.
  T.replace "~0" "~" . T.replace "~1" "/"

-- | Compare two ByteStrings without leaking where their contents differ.
--
-- Inputs of different lengths are rejected immediately, so timing may reveal
-- whether the lengths match. The length is not treated as secret. For
-- equal-length inputs, the contents are compared with 'BA.constEq' from
-- "Data.ByteArray". Its running time does not depend on the position of the
-- first differing byte.
--
-- SECURITY: Use this when comparing cryptographic values such as digests or
-- hashes, where an early-exit comparison could be exploited via timing.
constantTimeEq :: BS.ByteString -> BS.ByteString -> Bool
constantTimeEq lhs rhs = sameLength && BA.constEq lhs rhs
  where
    sameLength = BS.length lhs == BS.length rhs

-- | Group paths by their first segment, keeping the remaining segments.
--
-- This is a common pattern for processing nested JSON Pointer paths. Within
-- each group, the remaining paths appear in the same order as in the input.
--
-- An empty path makes sure the key @""@ exists in the result. It adds no
-- entry to that key's list. A path whose first segment is @""@ (e.g. @[""]@)
-- is grouped under the same key.
--
-- Example:
--
-- > groupPathsByFirstSegment [["a", "b"], ["a", "c"], ["x"]]
-- >   == Map.fromList [("a", [["b"], ["c"]]), ("x", [[]])]
groupPathsByFirstSegment :: [[T.Text]] -> Map.Map T.Text [[T.Text]]
groupPathsByFirstSegment nestedPaths =
  -- 'Map.fromListWith' combines as @f new old@, so prepending with (++) builds
  -- each group in reverse. Reversing once at the end restores input order
  -- and keeps the whole grouping O(n).
  Map.map reverse (Map.fromListWith (++) (map splitHead nestedPaths))
  where
    splitHead :: [T.Text] -> (T.Text, [[T.Text]])
    splitHead [] = ("", [])
    splitHead (segment : rest) = (segment, [rest])