{-# OPTIONS_HADDOCK not-home #-}
{-# LANGUAGE TupleSections #-}
module MagicWormhole.Internal.ClientProtocol
  ( Connection(..)
  , SessionKey(..)
  , PeerError(..)
  , sendEncrypted
  , receiveEncrypted
  , PlainText(..)
  
  , CipherText(..)
  , decrypt
  , encrypt
  , deriveKey
  , Purpose
  , phasePurpose
  ) where
import Protolude hiding (phase)
import Crypto.Hash (SHA256(..), hashWith)
import qualified Crypto.KDF.HKDF as HKDF
import qualified Crypto.Saltine.Internal.ByteSizes as ByteSizes
import qualified Crypto.Saltine.Class as Saltine
import qualified Crypto.Saltine.Core.SecretBox as SecretBox
import qualified Data.ByteArray as ByteArray
import qualified Data.ByteString as ByteString
import qualified MagicWormhole.Internal.Messages as Messages
-- | The operations this module needs from an underlying mailbox connection
-- in order to exchange encrypted messages with the peer.
data Connection
  = Connection
  { -- | Application identifier associated with this connection.
    appID :: Messages.AppID
    -- | Our own side. 'sendEncrypted' derives the key for outgoing
    -- messages from this value and the phase.
  , ourSide :: Messages.Side
    -- | Deliver a message body under the given phase. 'sendEncrypted'
    -- passes bodies that have already been encrypted.
  , send :: Messages.Phase -> Messages.Body -> IO ()
    -- | STM action yielding the next incoming mailbox message.
    -- 'receiveEncrypted' decrypts whatever this returns.
  , receive :: STM Messages.MailboxMessage
  }
-- | Shared secret from which per-message SecretBox keys are derived with
-- 'deriveKey'.
newtype SessionKey = SessionKey ByteString
-- | Encrypt a message and send it to the peer.
--
-- The encryption key is derived from the session key together with our
-- own side and the phase. The encrypted body is then handed to the
-- connection's 'send' action under the same phase.
sendEncrypted
  :: Connection -- ^ connection supplying our side and the 'send' action
  -> SessionKey -- ^ shared session key
  -> Messages.Phase -- ^ phase to send the message under
  -> PlainText -- ^ message to encrypt
  -> IO ()
sendEncrypted connection sessionKey msgPhase msg =
  encryptMessage connection sessionKey msgPhase msg >>= send connection msgPhase
-- | Take the next message from the connection and decrypt it.
--
-- The key is derived from the session key and from the side and phase
-- recorded on the incoming message. A decryption failure is raised as a
-- 'PeerError' via 'throwSTM'.
receiveEncrypted
  :: Connection -- ^ connection whose 'receive' supplies incoming messages
  -> SessionKey -- ^ shared session key
  -> STM (Messages.Phase, PlainText) -- ^ the message's phase and decrypted body
receiveEncrypted connection sessionKey =
  receive connection >>= either throwSTM pure . decryptMessage sessionKey
-- | Encrypt a message for sending and wrap the ciphertext bytes in a
-- 'Messages.Body'.
--
-- The key comes from 'deriveKey', using the purpose built from our own side
-- (taken from the connection) and the given phase.
encryptMessage :: Connection -> SessionKey -> Messages.Phase -> PlainText -> IO Messages.Body
encryptMessage connection sessionKey msgPhase msg = do
  let phaseKey = deriveKey sessionKey (phasePurpose (ourSide connection) msgPhase)
  CipherText sealed <- encrypt phaseKey msg
  pure (Messages.Body sealed)
-- | Encrypt with SecretBox under a freshly generated nonce.
--
-- The result is the encoded nonce followed by the SecretBox output, which is
-- the layout 'decrypt' splits apart again.
encrypt :: SecretBox.Key -> PlainText -> IO CipherText
encrypt secretKey (PlainText bytes) = fmap seal SecretBox.newNonce
  where
    -- Prefix the nonce so the receiver can recover it before opening the box.
    seal nonce =
      CipherText (Saltine.encode nonce <> SecretBox.secretbox secretKey nonce bytes)
-- | Decrypt the body of a mailbox message and pair it with the message's
-- phase.
--
-- The key is derived from the session key and the purpose built from the
-- side and phase carried on the message itself.
decryptMessage :: SessionKey -> Messages.MailboxMessage -> Either PeerError (Messages.Phase, PlainText)
decryptMessage key message = do
  plaintext <- decrypt phaseKey (CipherText ciphertext)
  pure (msgPhase, plaintext)
  where
    msgPhase = Messages.phase message
    Messages.Body ciphertext = Messages.body message
    phaseKey = deriveKey key (phasePurpose (Messages.side message) msgPhase)
-- | Open a ciphertext produced by 'encrypt'.
--
-- The first 'ByteSizes.secretBoxNonce' bytes are taken as the nonce and the
-- remainder as the SecretBox payload. Returns 'InvalidNonce' (carrying those
-- leading bytes) if they do not decode as a nonce. Returns 'CouldNotDecrypt'
-- (carrying the payload) if 'SecretBox.secretboxOpen' rejects it.
decrypt :: SecretBox.Key -> CipherText -> Either PeerError PlainText
decrypt key (CipherText ciphertext) =
  case Saltine.decode nonceBytes of
    Nothing -> Left (InvalidNonce nonceBytes)
    Just nonce ->
      maybe (Left (CouldNotDecrypt boxed)) (Right . PlainText)
        (SecretBox.secretboxOpen key nonce boxed)
  where
    (nonceBytes, boxed) = ByteString.splitAt ByteSizes.secretBoxNonce ciphertext
-- | Unencrypted message bytes.
newtype PlainText = PlainText { plainTextToByteString :: ByteString } deriving (Eq, Ord, Show)
-- | Encrypted message bytes. As produced by 'encrypt', this is the encoded
-- nonce followed by the SecretBox output.
newtype CipherText = CipherText { cipherTextToByteString :: ByteString } deriving (Eq, Ord, Show)
-- | HKDF info string passed to 'deriveKey'; see 'phasePurpose' for how the
-- per-message purpose is built.
type Purpose = ByteString
-- | Derive a SecretBox key from the session key using HKDF with SHA-256.
--
-- An empty salt is used for the extract step, and @purpose@ is the info
-- for the expand step. The output length is exactly 'ByteSizes.secretBoxKey'
-- bytes.
deriveKey
  :: SessionKey -- ^ shared session key (HKDF input keying material)
  -> Purpose -- ^ HKDF info, e.g. from 'phasePurpose'
  -> SecretBox.Key -- ^ derived SecretBox key
deriveKey (SessionKey ikm) purpose =
  case Saltine.decode okm of
    Just derived -> derived
    -- Expected to be unreachable: okm is expanded to the SecretBox key size.
    Nothing -> panic "Could not encode to SecretBox key"
  where
    prk = HKDF.extract ("" :: ByteString) ikm :: HKDF.PRK SHA256
    okm :: ByteString
    okm = HKDF.expand prk purpose ByteSizes.secretBoxKey
-- | Build the HKDF purpose for a message sent by a given side in a given
-- phase. The result is @"wormhole:phase:"@ followed by the SHA-256 digest of
-- the side and then the SHA-256 digest of the phase name.
phasePurpose :: Messages.Side -> Messages.Phase -> Purpose
phasePurpose (Messages.Side sideName) msgPhase =
  mconcat
    [ "wormhole:phase:"
    , sha256 (toS @Text @ByteString sideName)
    , sha256 (toS (Messages.phaseName msgPhase) :: ByteString)
    ]
  where
    -- Raw 32-byte digest, not a hex encoding.
    sha256 :: ByteString -> ByteString
    sha256 = ByteArray.convert . hashWith SHA256
-- | Errors arising while processing messages from the peer.
data PeerError
  -- | 'SecretBox.secretboxOpen' rejected the ciphertext (the bytes after the nonce).
  = CouldNotDecrypt ByteString
  -- | The leading bytes of a ciphertext did not decode as a SecretBox nonce.
  | InvalidNonce ByteString
  -- | Carries a phase and a plaintext. Not constructed in this module;
  -- presumably raised by callers when a message arrives in an unexpected
  -- phase (TODO confirm).
  | MessageOutOfOrder Messages.Phase PlainText
  deriving (Eq, Show, Typeable)
-- | Lets 'receiveEncrypted' raise decryption failures with 'throwSTM'.
instance Exception PeerError