-- | This module defines the multi-thread backend for the `Network` monad.
module Choreography.Network.Local where

import Choreography.Locations
import Choreography.Network
import Control.Concurrent
import Control.Monad
import Control.Monad.Freer
import Control.Monad.IO.Class
import Data.HashMap.Strict (HashMap, (!), (!?))
import Data.HashMap.Strict qualified as HashMap

-- | Each location is associated with a message buffer which stores messages sent
-- from other locations.
--
-- The buffer maps a location name to the channel of serialized (`String`)
-- messages for that name; `runNetworkLocal` indexes it by the /sending/ location.
type MsgBuf = HashMap LocTm (Chan String)

-- | A backend for running choreographies using [Haskell threads](https://hackage.haskell.org/package/base/docs/Control-Concurrent.html)
--   as the locations and buffered `Control.Concurrent.Chan.Chan` channels for communication.
newtype LocalConfig = LocalConfig
  { -- | The message buffer of each location, keyed by the /receiving/ location.
    --   Each buffer is in turn keyed by the sending location.
    locToBuf :: HashMap LocTm MsgBuf
  }

-- | Make a channel for each of the listed locations, on which messages from that location can be received.
--
--   The result has one freshly created, empty channel per listed location.
--   If a location is listed more than once, the channel created last is kept.
newEmptyMsgBuf :: [LocTm] -> IO MsgBuf
newEmptyMsgBuf = foldM addChan HashMap.empty
  where
    -- Create a fresh channel and register it under the given location.
    addChan :: MsgBuf -> LocTm -> IO MsgBuf
    addChan buf loc = do
      chan <- newChan
      pure (HashMap.insert loc chan buf)

-- | Make a local backend for the listed parties.
--   Make just the one backend and then have all your threads use the same one.
--
--   Every listed party gets its own message buffer, and each of those buffers
--   holds one channel per listed party (built with `newEmptyMsgBuf`).
mkLocalConfig :: [LocTm] -> IO LocalConfig
mkLocalConfig ls = LocalConfig <$> foldM addBuf HashMap.empty ls
  where
    -- Build the buffer for one receiving party and register it under its name.
    addBuf :: HashMap LocTm MsgBuf -> LocTm -> IO (HashMap LocTm MsgBuf)
    addBuf bufs loc = do
      buf <- newEmptyMsgBuf ls
      pure (HashMap.insert loc buf bufs)

-- | List the parties known to the backend, i.e. the keys of `locToBuf`.
--   The order of the result is whatever `HashMap.keys` yields.
locs :: LocalConfig -> [LocTm]
locs = HashMap.keys . locToBuf

-- | Run a `Network` behavior using the channels in a t`LocalConfig` for communication.
--   Call this inside a concurrent thread.
--
--   The second argument is the name of the party this thread plays. Messages
--   are serialized with `show` on send and parsed with `read` on receive.
--
--   Calls `error` if a send or receive refers to a party that is missing from
--   the t`LocalConfig`, or if the acting party itself is missing from it.
runNetworkLocal :: (MonadIO m) => LocalConfig -> LocTm -> Network m a -> m a
runNetworkLocal cfg self = interpFreer handler
  where
    handler :: (MonadIO m) => NetworkSig m a -> m a
    handler (Run m) = m
    -- Write the serialized value to every recipient's channel for this sender.
    -- Lookups are total so that an unknown party yields a descriptive error
    -- instead of an opaque failed-lookup error from the HashMap library.
    handler (Send a ls) = liftIO $ forM_ ls $ \l ->
      case locToBuf cfg !? l of
        Nothing -> error $ unknownParty l
        Just buf -> case buf !? self of
          Nothing -> error unknownSelf
          Just chan -> writeChan chan (show a)
    -- Block until the named sender has written to our channel for them.
    handler (Recv l) = liftIO $
      case locToBuf cfg !? self of
        Nothing -> error unknownSelf
        Just b -> case b !? l of
          -- NOTE: `read` is partial and lazy; a malformed message only fails
          -- once the received value is actually forced.
          Just q -> read <$> readChan q
          Nothing -> do
            -- Dump the known senders and the requested one before failing.
            print $ void b
            print l
            error $ unknownParty l

    -- Error text for a peer that has no entry in the configuration.
    unknownParty :: LocTm -> String
    unknownParty name = "We don't know how to contact the party named \"" <> name <> "\"."

    -- Error text for the case where the acting party itself has no entry.
    unknownSelf :: String
    unknownSelf = "The party named \"" <> self <> "\" is not known to this LocalConfig."

-- | The local multi-thread backend runs networks with `runNetworkLocal`.
instance Backend LocalConfig where
  runNetwork = runNetworkLocal