{-# LANGUAGE OverloadedStrings #-}
-- | Serialization and deserialization of SD-JWT structures.
--
-- This module provides functions to serialize and deserialize SD-JWTs
-- to/from the tilde-separated format specified in RFC 9901.
module SDJWT.Internal.Serialization
  ( serializeSDJWT
  , deserializeSDJWT
  , serializePresentation
  , deserializePresentation
  , parseTildeSeparated
  ) where

import SDJWT.Internal.Types (SDJWT(..), SDJWTPresentation(..), SDJWTError(..), EncodedDisclosure(..))
import Data.Maybe (fromMaybe)
import qualified Data.Text as T

-- | Serialize an 'SDJWT' to the tilde-separated format.
--
-- Format: @<Issuer-signed JWT>~<Disclosure 1>~<Disclosure 2>~...~<Disclosure N>~@
--
-- The trailing tilde is always present, even if there are no disclosures,
-- so an SD-JWT without disclosures serializes as @<Issuer-signed JWT>~@.
serializeSDJWT :: SDJWT -> T.Text
serializeSDJWT (SDJWT jwt sdDisclosures) =
  -- The empty final component makes 'T.intercalate' emit the trailing tilde
  -- that distinguishes a plain SD-JWT from an SD-JWT+KB.
  T.intercalate "~" (jwt : map unEncodedDisclosure sdDisclosures ++ [""])

-- | Deserialize an 'SDJWT' from the tilde-separated format.
--
-- Parses the input with 'parseTildeSeparated' and accepts it only when the
-- component after the last tilde is empty (plain SD-JWT format).
--
-- Returns 'SerializationError' if parsing fails, or if a Key Binding JWT is
-- present (use 'deserializePresentation' for SD-JWT+KB).
deserializeSDJWT :: T.Text -> Either SDJWTError SDJWT
deserializeSDJWT input = do
  (jwt, sdDisclosures, mbKbJwt) <- parseTildeSeparated input
  case mbKbJwt of
    -- 'Nothing' means the input ended with a tilde, i.e. no KB-JWT.
    Nothing -> Right (SDJWT jwt sdDisclosures)
    Just _ ->
      Left $
        SerializationError
          "SD-JWT should not have Key Binding JWT (use SD-JWT+KB format)"

-- | Serialize an 'SDJWTPresentation'.
--
-- Format: @<Issuer-signed JWT>~<Disclosure 1>~...~<Disclosure N>~[<KB-JWT>]@
--
-- If a Key Binding JWT is present, it is emitted as the last component.
-- Otherwise the last component is empty, leaving just the trailing tilde.
-- Note that @Just ""@ for the KB-JWT serializes identically to 'Nothing'.
serializePresentation :: SDJWTPresentation -> T.Text
serializePresentation (SDJWTPresentation jwt sdDisclosures mbKbJwt) =
  T.intercalate "~" $
    jwt : map unEncodedDisclosure sdDisclosures ++ [fromMaybe "" mbKbJwt]

-- | Deserialize an 'SDJWTPresentation'.
--
-- Parses a tilde-separated string with 'parseTildeSeparated'. This handles
-- both SD-JWT (without KB-JWT, trailing tilde) and SD-JWT+KB (with KB-JWT
-- as the last component) formats.
--
-- Any 'SDJWTError' produced by 'parseTildeSeparated' is returned unchanged.
deserializePresentation :: T.Text -> Either SDJWTError SDJWTPresentation
deserializePresentation input = do
  (jwt, sdDisclosures, mbKbJwt) <- parseTildeSeparated input
  pure (SDJWTPresentation jwt sdDisclosures mbKbJwt)

-- | Parse the tilde-separated format.
--
-- Low-level function that splits the input on @~@ and returns the
-- components: (issuer-signed JWT, [Disclosures], Maybe KB-JWT).
--
-- The last component is 'Nothing' when the input ends with a tilde (SD-JWT
-- format, empty string after the last tilde) and 'Just' the KB-JWT
-- otherwise (SD-JWT+KB format). An input with no tilde at all is accepted
-- as a bare JWT with no disclosures and no KB-JWT.
--
-- Returns 'SerializationError' if the input is empty or if the leading
-- issuer-signed JWT component is empty (e.g. @"~"@ or @"~disc~"@).
parseTildeSeparated :: T.Text -> Either SDJWTError (T.Text, [EncodedDisclosure], Maybe T.Text)
parseTildeSeparated input
  -- 'T.splitOn' never returns an empty list (it yields [""] for ""), so an
  -- empty input has to be rejected explicitly here.
  | T.null input = Left $ SerializationError "Empty SD-JWT"
  | otherwise =
      case T.splitOn "~" input of
        -- Unreachable given the guard above; kept so the match is total.
        [] -> Left $ SerializationError "Empty SD-JWT"
        jwt : rest
          | T.null jwt ->
              Left $ SerializationError "Missing issuer-signed JWT in SD-JWT"
          | otherwise ->
              let (disclosureParts, lastPart) = splitLast rest
               in Right (jwt, map EncodedDisclosure disclosureParts, lastPart)
  where
    -- Separate the final component (empty for SD-JWT, KB-JWT for SD-JWT+KB)
    -- from the disclosures preceding it. With no components after the JWT
    -- (input had no tilde), there are no disclosures and no KB-JWT.
    splitLast :: [T.Text] -> ([T.Text], Maybe T.Text)
    splitLast parts =
      case reverse parts of
        [] -> ([], Nothing)
        lastItem : revDisclosures ->
          ( reverse revDisclosures
          , if T.null lastItem then Nothing else Just lastItem
          )