{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings     #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE StrictData            #-}

module Tox.Transport.SecureSession.Manager where

import           Control.Monad.State              (MonadState, gets, modify,
                                                   runStateT)
import           Data.ByteString                  (ByteString)
import           Data.Map                         (Map)
import qualified Data.Map                         as Map

import           Tox.Core.Timed                   (Timed)
import           Tox.Crypto.Core.Key              (CombinedKey, PublicKey)
import           Tox.Crypto.Core.Keyed            (Keyed)
import           Tox.Crypto.Core.KeyPair          (KeyPair)
import           Tox.Crypto.Core.MonadRandomBytes (MonadRandomBytes)
import           Tox.Network.Core.Networked       (Networked)
import qualified Tox.Network.Core.NodeInfo        as NodeInfo
import           Tox.Network.Core.NodeInfo        (NodeInfo)
import           Tox.Network.Core.Packet          (Packet (..))
import qualified Tox.Network.Core.PacketKind      as PacketKind
import           Tox.Transport.SecureSession      (SecureSessionState,
                                                   handleCookieRequest,
                                                   handlePacket)

-- | State shared by all secure sessions handled by this node.
data SessionManager = SessionManager
  { sessionsByPk  :: Map PublicKey SecureSessionState
    -- ^ Per-peer session state; 'dispatchPacket' looks entries up by the
    -- sender's 'NodeInfo.publicKey'.
  , cookieKey     :: CombinedKey
    -- ^ Key handed to 'handleCookieRequest' and 'handlePacket'.
  , ourDhtKeyPair :: KeyPair
    -- ^ Key pair handed to 'handleCookieRequest'.
  }

-- | Handle an incoming packet, dispatching to the correct session.
--
-- * A 'PacketKind.CookieRequest' is answered directly through
--   'handleCookieRequest', using the manager's 'cookieKey' and
--   'ourDhtKeyPair'; no entry in 'sessionsByPk' is consulted.
-- * Any other packet is routed to the session stored under the sender's
--   public key. If there is no such session the packet is ignored.
--   Otherwise 'handlePacket' runs against that session's state and the
--   resulting state replaces the stored entry.
--
-- NOTE(review): this function never creates a session; presumably entries
-- are inserted into 'sessionsByPk' elsewhere -- confirm against callers.
dispatchPacket :: (Timed m, MonadRandomBytes m, Keyed m, Networked m, MonadState SessionManager m)
               => NodeInfo -> Packet ByteString -> m ()
dispatchPacket from pkt@(Packet kind payload) = case kind of
  PacketKind.CookieRequest -> do
    ck <- gets cookieKey
    dk <- gets ourDhtKeyPair
    handleCookieRequest ck dk from payload
  _ -> do
    let pk = NodeInfo.publicKey from
    mSession <- gets (Map.lookup pk . sessionsByPk)
    case mSession of
      -- No session for this sender: drop the packet.
      Nothing -> return ()
      Just session -> do
        ck <- gets cookieKey
        -- Run the per-session handler in its own StateT layer so that it
        -- only sees (and can only modify) this peer's session state.
        ((), session') <- runStateT (handlePacket ck from pkt) session
        -- Write the updated session back under the same public key.
        modify $ \s -> s { sessionsByPk = Map.insert pk session' (sessionsByPk s) }