{-# OPTIONS_GHC -Wall -Werror #-}
{-# LANGUAGE NoGeneralizedNewtypeDeriving #-}
{-# LANGUAGE Safe #-}
{-# LANGUAGE LambdaCase #-}
module Agent.Control.Concurrent
( Async
, Task
, fork
, wait
, join
)
where
import Control.Concurrent ( ThreadId, forkFinally )
import Control.Concurrent.MVar
( MVar
, newEmptyMVar
, putMVar
, readMVar
)
import GHC.Exception ( SomeException, throw )
-- | Handle to a computation started on its own thread.
data Task a
  = Task
      !ThreadId
      -- ^ identifier of the thread running the computation
      (IO (Either SomeException a))
      -- ^ action that yields the computation's outcome: either the
      --   exception it died with or its result
-- | Minimal mailbox interface over 'MVar' in some monad @m@.
class Monad m => Proxy m where
  -- | Allocate a fresh mailbox (the 'IO' instance creates it empty).
  new :: m (MVar a)

  -- | Store a value into the given mailbox.
  put
    :: MVar a -- ^ target mailbox
    -> a      -- ^ value to store
    -> m ()

  -- | Read the mailbox's value (the 'IO' instance uses 'readMVar').
  get
    :: MVar a -- ^ mailbox to read
    -> m a
-- | Spawning and awaiting concurrent computations in @m@.
class Proxy m => Async m where
  -- | Start the computation concurrently and return a handle to it.
  fork :: m a -> m (Task a)

  -- | Obtain the task's result. In the 'IO' instance this blocks until the
  --   task has finished.
  wait :: Task a -> m a

  -- | Wait for every task in the list, discarding their results.
  join :: [Task a] -> m ()
-- | 'Proxy' implemented directly with "Control.Concurrent.MVar".
instance Proxy IO where
  -- new :: IO (MVar a)
  -- Allocates an empty MVar.
  new = newEmptyMVar

  -- put :: MVar a -> a -> IO ()
  -- Fills the MVar (blocks while it is already full).
  put = putMVar

  -- get :: MVar a -> IO a
  -- Reads without emptying the MVar, so repeated reads observe the same
  -- value; blocks while the MVar is empty.
  get = readMVar
-- | 'Async' for plain 'IO' threads: each task owns an 'MVar' that receives
--   its outcome when the thread terminates.
instance Async IO where
  -- fork :: IO a -> IO (Task a)
  -- Runs @compute@ on a new thread. 'forkFinally' hands the thread's
  -- outcome (result or exception) to @put var@, and the returned 'Task'
  -- reads it back with 'get'. 'get' does not empty the MVar, so the
  -- outcome can be read any number of times.
  fork compute = do
    var <- new
    tid <- forkFinally compute (put var)
    pure (Task tid (get var))

  -- wait :: Task a -> IO a
  -- Blocks until the task has finished, then returns its result or rethrows
  -- its exception. The rethrow is sequenced inside the IO action rather
  -- than mapped over the result as a pure thunk. Otherwise the exception
  -- would only surface when the value is forced, and callers that discard
  -- the result (such as 'join') would silently lose it. It is applied after
  -- the MVar read in the bind, so it fires exactly when 'wait' runs.
  wait (Task _ outcome) = outcome >>= either throw pure

  -- join :: [Task a] -> IO ()
  -- Waits for the tasks in list order. The first task that failed has its
  -- exception rethrown here, and any remaining tasks are not waited on.
  join = mapM_ wait