{-# LANGUAGE OverloadedStrings #-}
module SDJWT.Internal.Utils
( base64urlEncode
, base64urlDecode
, textToByteString
, byteStringToText
, hashToBytes
, splitJSONPointer
, unescapeJSONPointer
, constantTimeEq
, generateSalt
, groupPathsByFirstSegment
) where
import qualified Data.ByteString.Base64.URL as Base64
import qualified Data.ByteString as BS
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Crypto.Random as RNG
import qualified Crypto.Hash as Hash
import qualified Data.ByteArray as BA
import qualified Data.Map.Strict as Map
import Control.Monad.IO.Class (MonadIO, liftIO)
import SDJWT.Internal.Types (HashAlgorithm(..))
-- | Encode raw bytes as unpadded base64url text (RFC 4648 section 5
-- alphabet, no trailing @=@ characters).
--
-- The encoder only ever emits ASCII characters, so reading its output as
-- UTF-8 cannot fail.
base64urlEncode :: BS.ByteString -> T.Text
base64urlEncode bytes = TE.decodeUtf8 (Base64.encodeUnpadded bytes)
-- | Decode unpadded base64url text (RFC 4648 section 5) into raw bytes.
--
-- The text is UTF-8 encoded and handed to 'Base64.decodeUnpadded'. Any
-- failure it reports is returned as 'Left' with the decoder's message.
-- NOTE(review): the decoder is the *unpadded* variant, so input carrying
-- trailing @=@ padding is presumably rejected — confirm this is wanted.
base64urlDecode :: T.Text -> Either T.Text BS.ByteString
base64urlDecode encoded =
  -- The decoder's error is already a 'String'. Pack it directly: going
  -- through 'show' would wrap the message in quotes and escape its contents.
  either (Left . T.pack) Right (Base64.decodeUnpadded (TE.encodeUtf8 encoded))
-- | Produce a fresh random salt of 16 bytes (128 bits).
--
-- The bytes come from 'RNG.getRandomBytes' run in 'IO'. The action is
-- lifted so it can be used from any 'MonadIO' context.
generateSalt :: MonadIO m => m BS.ByteString
generateSalt = liftIO (RNG.getRandomBytes saltLength)
  where
    -- Number of random bytes in each salt.
    saltLength :: Int
    saltLength = 16
-- | Encode text as UTF-8 bytes.
--
-- This is total: every 'T.Text' value has a UTF-8 encoding.
textToByteString :: T.Text -> BS.ByteString
textToByteString txt = TE.encodeUtf8 txt
-- | Decode UTF-8 bytes into text.
--
-- NOTE(review): 'TE.decodeUtf8' throws an exception on invalid UTF-8
-- instead of returning an error value. Only pass bytes already known to be
-- valid UTF-8; for untrusted input, callers should use 'TE.decodeUtf8''
-- directly.
byteStringToText :: BS.ByteString -> T.Text
byteStringToText bytes = TE.decodeUtf8 bytes
-- | Hash the input with the selected SHA-2 variant and return the raw
-- digest bytes: 32 for SHA-256, 48 for SHA-384, 64 for SHA-512.
hashToBytes :: HashAlgorithm -> BS.ByteString -> BS.ByteString
hashToBytes alg input =
  case alg of
    SHA256 -> digestWith Hash.SHA256
    SHA384 -> digestWith Hash.SHA384
    SHA512 -> digestWith Hash.SHA512
  where
    -- The explicit signature keeps this helper polymorphic in the
    -- algorithm, so one definition serves all three branches.
    digestWith :: Hash.HashAlgorithm a => a -> BS.ByteString
    digestWith hashAlg = BA.convert (Hash.hashWith hashAlg input)
-- | Split a JSON-Pointer-style path into its decoded reference tokens.
--
-- The path is split on every literal @\/@. Each non-empty piece is then
-- decoded: @~1@ becomes @\/@ and @~0@ becomes @~@. The @~1@ rewrite runs
-- first, so @~01@ decodes to @~1@ rather than @\/@ (RFC 6901 section 4).
--
-- Empty pieces are discarded. A leading slash is therefore optional and
-- repeated or trailing slashes are ignored: @\"\/a\/b\"@, @\"a\/b\"@ and
-- @\"a\/\/b\/\"@ all yield @[\"a\", \"b\"]@, while @\"\"@ and @\"\/\"@
-- yield @[]@.
--
-- NOTE(review): RFC 6901 treats empty reference tokens as significant
-- (@\"\/a\/\/b\"@ would be @[\"a\", \"\", \"b\"]@). This function drops
-- them; confirm callers rely on that before changing it.
--
-- Runs in time linear in the length of the path.
splitJSONPointer :: T.Text -> [T.Text]
splitJSONPointer path =
  [ decodeToken raw | raw <- T.splitOn "/" path, not (T.null raw) ]
  where
    -- Decode one raw token. Every escape sequence and every plain character
    -- decodes to at least one character, so a non-empty raw piece never
    -- becomes empty. Filtering before decoding is therefore equivalent to
    -- filtering after.
    decodeToken :: T.Text -> T.Text
    decodeToken = T.replace "~0" "~" . T.replace "~1" "/"
-- | Decode the RFC 6901 escape sequences in one JSON Pointer reference
-- token: @~1@ becomes @\/@ and @~0@ becomes @~@.
--
-- Function composition applies right to left, so the @~1@ rewrite runs
-- first. RFC 6901 section 4 requires this order. Rewriting @~0@ first would
-- turn @~01@ into @~1@ and then into @\/@, but the correct result is @~1@.
unescapeJSONPointer :: T.Text -> T.Text
unescapeJSONPointer = T.replace "~0" "~" . T.replace "~1" "/"
-- | Compare two byte strings for equality using 'BA.constEq', whose
-- running time does not depend on where the contents first differ.
--
-- Inputs of different lengths are rejected immediately. Only the lengths,
-- not the contents, influence timing in that case.
constantTimeEq :: BS.ByteString -> BS.ByteString -> Bool
constantTimeEq lhs rhs =
  BS.length lhs == BS.length rhs && BA.constEq lhs rhs
-- | Group segment paths by their first segment. Each first segment maps to
-- the list of remaining tails that followed it.
--
-- Ordering: 'Map.fromListWith' combines colliding entries as
-- @new ++ old@, so within a group the tails appear in the reverse of their
-- input order. For example, @[["a","x"],["a","y"]]@ yields
-- @"a" -> [["y"],["x"]]@.
--
-- An empty path contributes the key @""@ with no tail. It makes that key
-- present (possibly mapping to @[]@) without adding an entry.
-- NOTE(review): this shares the @""@ key with any path whose first segment
-- really is the empty string.
groupPathsByFirstSegment :: [[T.Text]] -> Map.Map T.Text [[T.Text]]
groupPathsByFirstSegment nestedPaths =
  Map.fromListWith (++) [ keyed segments | segments <- nestedPaths ]
  where
    -- Pair a path's head with a singleton list holding its tail.
    keyed :: [T.Text] -> (T.Text, [[T.Text]])
    keyed []        = (T.empty, [])
    keyed (hd : tl) = (hd, [tl])