{-# LANGUAGE OverloadedStrings #-}

module Network.TLS.Handshake.TranscriptHash (
    transcriptHash,
    transcriptHashWith,
    transitTranscriptHashI,
    updateTranscriptHash,
    updateTranscriptHashI,
    transitTranscriptHash,
    copyTranscriptHash,
    TranscriptHash (..),
) where

import Control.Monad.State
import qualified Data.ByteString as B

import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.Parameters
import Network.TLS.State
import Network.TLS.Types

----------------------------------------------------------------

-- | Switch the primary transcript ('hstTransHashState') from its buffered
-- stage to its hashing stage using 'transit', then emit a debug trace of the
-- resulting state.
--
-- @label@ only annotates debug output and error messages. @hashAlg@ is the
-- hash used to start the running context. @isHRR@ selects the
-- HelloRetryRequest path, in which the buffered first message is replaced by
-- a synthetic message carrying its digest.
transitTranscriptHash :: Context -> String -> Hash -> Bool -> IO ()
transitTranscriptHash ctx label hashAlg isHRR = do
    usingHState ctx $ modify' $ \hst ->
        hst{hstTransHashState = transit label hashAlg isHRR $ hstTransHashState hst}
    traceTranscriptHash ctx label hstTransHashState

-- | Like 'transitTranscriptHash', but operates on the @I@-suffixed
-- transcript ('hstTransHashStateI') instead of the primary one.
--
-- NOTE(review): the @I@ transcript is presumably the inner-ClientHello
-- transcript used for ECH — confirm against the handshake code.
transitTranscriptHashI :: Context -> String -> Hash -> Bool -> IO ()
transitTranscriptHashI ctx label hashAlg isHRR = do
    usingHState ctx $ modify' $ \hst ->
        hst{hstTransHashStateI = transit label hashAlg isHRR $ hstTransHashStateI hst}
    traceTranscriptHash ctx label hstTransHashStateI

-- | Advance a transcript-hash state from the buffered stage to the hashing
-- stage.
--
-- * 'TransHashState0' (nothing recorded yet) is a programming error and
--   raises 'error' mentioning @label@ and the state.
-- * 'TransHashState2' (already hashing) is returned unchanged, so repeated
--   transitions are harmless.
-- * 'TransHashState1' holding the first buffered message @ch@ becomes a
--   'TransHashState2' whose context was started with @hashAlg@:
--
--     * normally @ch@ itself is fed into the new context;
--     * when @isHRR@ is set, the synthetic handshake message
--       @254 0 0 len hash(ch)@ is fed instead. This corresponds to the
--       @message_hash@ construction of RFC 8446 §4.4.1.
transit :: String -> Hash -> Bool -> TransHashState -> TransHashState
transit label _ _ st0@TransHashState0 =
    error $ "transitTranscriptHash " ++ label ++ " " ++ show st0
transit _ _ _ st2@(TransHashState2 _) = st2
transit _ hashAlg isHRR (TransHashState1 ch)
    | isHRR = TransHashState2 $ newWith hsMsg
    | otherwise = TransHashState2 $ newWith ch
  where
    -- A fresh context for 'hashAlg', seeded with one initial chunk.
    newWith = hashUpdate $ hashInit hashAlg
    hsMsg =
        -- Handshake message:
        -- typ <-len-> body
        -- 254 0 0 len hash(CH1)
        B.concat
            [ "\254\0\0"
            , B.singleton len
            , hashedCH
            ]
      where
        hashedCH = hash hashAlg ch
        -- Only the low byte of the 24-bit length is written; this assumes
        -- the digest is shorter than 256 bytes (confirm for new hashes).
        len = fromIntegral $ B.length hashedCH

----------------------------------------------------------------

-- | Append the bytes @eh@ to the primary transcript ('hstTransHashState')
-- using 'update', then emit a debug trace of the resulting state.
--
-- @eh@ is presumably an encoded handshake message — confirm at call sites.
-- @label@ only annotates debug output and error messages.
updateTranscriptHash :: Context -> String -> ByteString -> IO ()
updateTranscriptHash ctx label eh = do
    usingHState ctx $ modify' $ \hst ->
        hst{hstTransHashState = update eh label $ hstTransHashState hst}
    traceTranscriptHash ctx label hstTransHashState

-- | Like 'updateTranscriptHash', but appends @eh@ to the @I@-suffixed
-- transcript ('hstTransHashStateI') instead of the primary one.
updateTranscriptHashI :: Context -> String -> ByteString -> IO ()
updateTranscriptHashI ctx label eh = do
    usingHState ctx $ modify' $ \hst ->
        hst{hstTransHashStateI = update eh label $ hstTransHashStateI hst}
    traceTranscriptHash ctx label hstTransHashStateI

-- | Record the bytes @eh@ in a transcript-hash state.
--
-- * 'TransHashState0': the first message is buffered verbatim as
--   'TransHashState1', because the hash algorithm is not fixed yet.
-- * 'TransHashState2': @eh@ is fed into the running hash context.
-- * 'TransHashState1': a second message before 'transit' is a programming
--   error and raises 'error' mentioning @label@ and the state.
update :: ByteString -> String -> TransHashState -> TransHashState
update eh _ TransHashState0 = TransHashState1 eh
update eh _ (TransHashState2 hctx) = TransHashState2 $ hashUpdate hctx eh
update _ label st =
    error $ "updateTranscriptHash " ++ label ++ " " ++ show st

----------------------------------------------------------------

-- | Finalize the primary transcript ('hstTransHashState') into a digest,
-- trace it (label plus hex bytes), and return it as a 'TranscriptHash'.
--
-- The transcript must already be in the hashing stage; see 'calc'.
-- A missing handshake state raises an 'error' naming this function and
-- @label@. Previously this was an anonymous 'fromJust' failure.
transcriptHash :: MonadIO m => Context -> String -> m TranscriptHash
transcriptHash ctx label = do
    hst <- maybe noHState id <$> getHState ctx
    let th = calc label $ hstTransHashState hst
    liftIO $ debugTraceKey (ctxDebug ctx) $ adjustLabel label ++ showBytesHex th
    return $ TranscriptHash th
  where
    noHState = error $ "transcriptHash " ++ label ++ ": no handshake state"

-- | Finalize a transcript that is in the hashing stage ('TransHashState2').
--
-- Any other state is a programming error and raises 'error' mentioning
-- @label@ and the state.
calc :: String -> TransHashState -> ByteString
calc _ (TransHashState2 hashCtx) = hashFinal hashCtx
calc label st = error $ "transcriptHash " ++ label ++ " " ++ show st

----------------------------------------------------------------

-- | Compute the digest of a transcript extended with the extra bytes @bs@,
-- without storing @bs@ in the handshake state.
--
-- Which transcript is used depends on the local role. A client reads the
-- @I@-suffixed transcript ('hstTransHashStateI'); any other role reads the
-- primary one ('hstTransHashState'). The result is traced (label plus hex
-- bytes) and returned as a 'TranscriptHash'.
--
-- A missing handshake state raises an 'error' naming this function and
-- @label@. Previously this was an anonymous 'fromJust' failure.
transcriptHashWith
    :: MonadIO m => Context -> String -> ByteString -> m TranscriptHash
transcriptHashWith ctx label bs = do
    role <- liftIO $ usingState_ ctx getRole
    let isClient = role == ClientRole
    hst <- maybe noHState id <$> getHState ctx
    let st
            | isClient = hstTransHashStateI hst
            | otherwise = hstTransHashState hst
        th = calcWith bs label st
    liftIO $ debugTraceKey (ctxDebug ctx) $ adjustLabel label ++ showBytesHex th
    return $ TranscriptHash th
  where
    noHState = error $ "transcriptHashWith " ++ label ++ ": no handshake state"

-- | Finalize a hashing-stage transcript after feeding it the extra bytes
-- @bs@.
--
-- The update is applied to a copy of the context, so the caller's state is
-- unaffected. Any state other than 'TransHashState2' is a programming error
-- and raises 'error' mentioning @label@ and the state.
calcWith :: ByteString -> String -> TransHashState -> ByteString
calcWith bs _ (TransHashState2 hashCtx) = hashFinal $ hashUpdate hashCtx bs
calcWith _ label st =
    error $ "transcriptHashWith " ++ label ++ " " ++ show st

----------------------------------------------------------------

-- | Overwrite the primary transcript ('hstTransHashState') with the current
-- @I@-suffixed transcript ('hstTransHashStateI'), then trace the result.
--
-- The @I@ transcript itself is left as-is.
copyTranscriptHash :: Context -> String -> IO ()
copyTranscriptHash ctx label = do
    usingHState ctx $ modify' $ \hst ->
        hst
            { hstTransHashState = hstTransHashStateI hst
            }
    traceTranscriptHash ctx label hstTransHashState

----------------------------------------------------------------

-- | Emit a debug trace of the transcript-hash state selected by @getField@,
-- prefixed with @label@ padded by 'adjustLabel'.
--
-- The handshake state is bound lazily. A missing state therefore only
-- raises the (labelled) 'error' if the trace message is actually forced.
traceTranscriptHash
    :: Context -> String -> (HandshakeState -> TransHashState) -> IO ()
traceTranscriptHash ctx label getField = do
    hst <- maybe noHState id <$> getHState ctx
    debugTraceKey (ctxDebug ctx) $ adjustLabel label ++ show (getField hst)
  where
    noHState = error $ "traceTranscriptHash " ++ label ++ ": no handshake state"

-- | Left-justify a debug label in a fixed 24-character column.
--
-- Longer labels are truncated and shorter ones are padded with spaces. The
-- previous fixed 22-space padding literal left 0- and 1-character labels
-- short of the column width.
adjustLabel :: String -> String
adjustLabel label = take labelWidth (label ++ repeat ' ')
  where
    labelWidth = 24