{-# OPTIONS_GHC -Wall -Werror #-}

{-# LANGUAGE NoGeneralizedNewtypeDeriving #-}
{-# LANGUAGE Safe                         #-}

{-# LANGUAGE LambdaCase                   #-}

--------------------------------------------------------------------------------

-- |
-- Copyright  : (c) 2026 SPISE MISU ApS
-- License    : SSPL-1.0 OR AGPL-3.0-only
-- Maintainer : SPISE MISU <mail+hackage@spisemisu.com>
-- Stability  : experimental

--------------------------------------------------------------------------------

module Agent.Control.Concurrent
  ( Async
  , Task
  , fork
  , wait
  , join
  )
where

--------------------------------------------------------------------------------

import           Control.Concurrent      ( ThreadId, forkFinally )
import           Control.Concurrent.MVar
  ( MVar
  , newEmptyMVar
  , putMVar
  , readMVar
  )
import           GHC.Exception           ( SomeException, throw )

--------------------------------------------------------------------------------

-- | Handle to a forked computation. Only the type is exported, so the
-- constructor can be built and inspected only inside this module.
data Task a
  = Task
      !ThreadId
      -- ^ the thread the computation was started in
      (IO (Either SomeException a))
      -- ^ action yielding the outcome: the value, or the exception it failed with

--------------------------------------------------------------------------------

-- |
-- Sealed middle layer between `Monad` and `Async`. `Proxy` is not exported,
-- so code outside this module cannot write `Proxy` instances; and since
-- `Proxy` is a superclass of `Async`, it cannot add `Async` instances either.

class Monad m => Proxy m where
  -- | Allocate a fresh, empty 'MVar'.
  new :: m (MVar a)

  -- | Store a value into an 'MVar'.
  put :: MVar a -> a -> m ()

  -- | Read the value held in an 'MVar'.
  get :: MVar a -> m a

-- | Monads able to start computations and later collect their results via
-- 'Task' handles. Only the 'IO' instance defined in this module exists.
class Proxy m => Async m where
  -- | Start a computation and return a 'Task' handle for it.
  fork :: m a -> m (Task a)

  -- | Obtain the result of a task; a failed task's exception is re-thrown.
  wait :: Task a -> m a

  -- | Wait for every task in the list, discarding their results.
  join :: [Task a] -> m ()

--------------------------------------------------------------------------------

-- 'IO' is the only 'Proxy' instance: each method delegates directly to the
-- matching primitive from "Control.Concurrent.MVar".
instance Proxy IO where
  -- | newEmptyMVar: Create an MVar which is initially empty.
  new = newEmptyMVar
  -- | putMVar: Put a value into an MVar.
  put = putMVar
  -- |
  -- readMVar: Atomically read the contents of an MVar. If the MVar is currently
  -- empty, readMVar will wait until it is full. readMVar is guaranteed to
  -- receive the next putMVar.
  --
  -- A read (rather than a take) leaves the value in place, so the same MVar
  -- can be read any number of times.
  get = readMVar

instance Async IO where
  -- Run @compute@ in a new thread. When that thread ends, 'forkFinally'
  -- hands its outcome (the value, or the exception that ended it) to
  -- 'put', storing it in a fresh 'MVar'. The returned 'Task' pairs the
  -- thread id with an action that reads that 'MVar'.
  fork compute =
    new >>= \ var ->
    forkFinally compute (put var) >>= \ tid ->
    pure $ Task tid $ get var

  -- Block until the task's outcome is available, then return the value or
  -- re-raise the exception that terminated the task.
  --
  -- The exception is raised while the returned 'IO' action runs ('throw'
  -- sits in 'IO' position), not hidden inside a lazy result value. Otherwise
  -- a caller that discards the result, such as 'join', would never see it.
  wait (Task _ outcome) =
    outcome >>= \case
      Right v -> pure v
      Left  e -> throw e

  -- Wait for the tasks in list order. The exception of the first task (in
  -- list order) that failed is re-raised, and later tasks are not waited for.
  join =
    mapM_ wait