{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase   #-}
module Supervisors
    ( Supervisor
    , withSupervisor
    , supervise
    , superviseSTM
    ) where
import Control.Concurrent.STM
import Control.Concurrent       (ThreadId, forkIO, myThreadId, throwTo)
import Control.Concurrent.Async (withAsync)
import Control.Exception.Safe
    (Exception, SomeException, bracket, bracket_, toException, withException)
import Control.Monad            (forever, void)
import Data.Foldable            (traverse_)
import qualified Data.Set as S
-- | A handle for a group of threads that are managed together.
--
-- Threads are attached with 'supervise' (or queued with 'superviseSTM');
-- once the supervisor is shut down with an exception, that exception is
-- thrown to every thread still attached.
data Supervisor = Supervisor
    { stateVar :: TVar (Either SomeException (S.Set ThreadId))
      -- ^ @Right kids@ while the supervisor is live, where @kids@ are the
      -- 'ThreadId's of the threads currently registered by 'supervise'.
      -- @Left e@ once 'throwKids' has shut the supervisor down with @e@.
    , runQ     :: TQueue (IO ())
      -- ^ Tasks enqueued by 'superviseSTM' that are waiting for
      -- 'runSupervisor' to hand them to 'supervise'.
    }
-- | Allocate a fresh 'Supervisor' with no registered threads and an empty
-- task queue.
--
-- This only creates the shared state; it does not start 'runSupervisor',
-- so tasks queued with 'superviseSTM' stay queued until that loop runs.
newSupervisor :: IO Supervisor
newSupervisor =
    -- Constructor arguments are positional: state variable, then queue.
    Supervisor
        <$> newTVarIO (Right S.empty)
        <*> newTQueueIO
-- | The supervisor's main loop.
--
-- Repeatedly takes the next task off the supervisor's queue and forks it
-- with 'supervise'. The loop never finishes on its own; it only stops
-- when an exception is raised in it. That exception is then handed to
-- 'throwKids', which forwards it to every registered thread.
runSupervisor :: Supervisor -> IO ()
runSupervisor sup@Supervisor{runQ = pending} =
    drainQueue `withException` killKids
  where
    -- Block until a task is available, fork it, repeat.
    drainQueue :: IO ()
    drainQueue = forever $ do
        task <- atomically (readTQueue pending)
        supervise sup task

    -- Runs for any exception, hence the 'SomeException' handler type.
    killKids :: SomeException -> IO ()
    killKids = throwKids sup
-- | Run @f@ with access to a newly created 'Supervisor'.
--
-- For the duration of @f@, 'runSupervisor' runs alongside it (via
-- 'withAsync') so that tasks queued with 'superviseSTM' get forked.
-- 'withAsync' cancels that loop when @f@ returns or throws; the loop's
-- exception handler then passes the exception on to the threads still
-- registered with the supervisor.
--
-- The result is whatever @f@ produces.
withSupervisor :: (Supervisor -> IO a) -> IO a
withSupervisor f =
    newSupervisor >>= \sup ->
        -- The handle to the background loop is not needed by the caller.
        withAsync (runSupervisor sup) (\_loop -> f sup)
-- | Shut the supervisor down with @exn@ and throw @exn@ to every thread
-- currently registered with it.
--
-- The state is switched to @Left@ atomically, so the set of children is
-- claimed exactly once. If the supervisor was already shut down, nothing
-- is thrown and the previously recorded exception is kept.
--
-- All of the work sits in the acquire and release actions of 'bracket'
-- and the body is a no-op; presumably this is so that marking the
-- supervisor dead and notifying the children run in 'bracket''s masked
-- sections and cannot be separated by an asynchronous exception.
throwKids :: Exception e => Supervisor -> e -> IO ()
throwKids Supervisor{stateVar = kidsVar} exn =
    bracket claimKids notifyAll (const (pure ()))
  where
    -- Record the exception and return the threads that were registered
    -- at that moment (none if the supervisor was already dead).
    claimKids = atomically $ do
        current <- readTVar kidsVar
        case current of
            Left _ ->
                pure S.empty
            Right kids -> do
                writeTVar kidsVar (Left (toException exn))
                pure kids

    -- Deliver the exception to each claimed thread.
    notifyAll kids = traverse_ (\tid -> throwTo tid exn) kids
-- | Fork @task@ in a new thread that is registered with the supervisor
-- for as long as it runs.
--
-- The new thread adds its own 'ThreadId' to the supervisor before running
-- @task@ and removes it again when @task@ ends, whether normally or by an
-- exception. While registered, the thread receives any exception that
-- 'throwKids' delivers. If the supervisor has already been shut down, the
-- recorded exception is rethrown in the new thread and @task@ is never
-- started.
--
-- This call returns as soon as the thread is forked; the 'ThreadId' is
-- not exposed to the caller.
supervise :: Supervisor -> IO () -> IO ()
supervise Supervisor{stateVar = kidsVar} task =
    void (forkIO (bracket_ register unregister task))
  where
    -- Add the current thread to the set of children, or fail with the
    -- stored exception if the supervisor is no longer live.
    register = do
        tid <- myThreadId
        atomically $ readTVar kidsVar >>= \case
            Left e ->
                throwSTM e
            Right kids ->
                -- Force the updated set before storing it so the TVar
                -- never holds a chain of pending inserts.
                writeTVar kidsVar . Right $! S.insert tid kids

    -- Remove the current thread from the set of children so that
    -- finished threads do not accumulate there.
    unregister = do
        tid <- myThreadId
        atomically . modifyTVar' kidsVar $ \st -> case st of
            Left _ ->
                -- Already shut down: there is no set left to update.
                st
            Right kids ->
                -- modifyTVar' only evaluates as far as the 'Right'
                -- constructor; ($!) forces the set itself, otherwise
                -- unevaluated deletes would pile up inside the TVar.
                Right $! S.delete tid kids
-- | Transactional variant of 'supervise': enqueue a task on the
-- supervisor's queue from within 'STM'.
--
-- Nothing is forked here. The task is started later, when
-- 'runSupervisor' takes it off the queue and passes it to 'supervise'.
superviseSTM :: Supervisor -> IO () -> STM ()
superviseSTM Supervisor{runQ = pending} = writeTQueue pending