{-
  Copyright (c) Meta Platforms, Inc. and affiliates.
  All rights reserved.

  This source code is licensed under the BSD-style license found in the
  LICENSE file in the root directory of this source tree.
-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | A more flexible way to create 'Async's and have them automatically
-- cancelled when the 'Warden' is shut down.
module Control.Concurrent.Async.Warden
  ( Warden
  , withWarden
  , create
  , shutdown
  , spawn
  , spawn_
  , spawnMask
  , WardenException(..)
  ) where

import Control.Concurrent (forkIO)
import Control.Concurrent.Async (Async)
import qualified Control.Concurrent.Async as Async
import Control.Concurrent.MVar
import Control.Exception
import Data.HashSet (HashSet)
import qualified Data.HashSet as HashSet
import System.IO (fixIO)

#if defined(__MHS__)
import Prelude hiding(mapM_)
import Control.Monad hiding(mapM_)
import Data.Foldable(mapM_)
#else
import Control.Monad
#endif

-- | A 'Warden' is an owner of 'Async's which cancels them on 'shutdown'.
--
-- The 'MVar' holds the set of threads spawned through this 'Warden' that have
-- not yet removed themselves (a spawned thread calls 'forget' when its action
-- finishes). Each 'Async' is stored with its result discarded via 'void', so
-- threads with different result types can live in the same set.
--
-- 'Nothing' in the MVar means the 'Warden' has been shut down.
newtype Warden = Warden (MVar (Maybe (HashSet (Async ()))))

-- | Scoped 'Warden' usage: a fresh 'Warden' is obtained with 'create', handed
-- to @body@, and then 'shutdown' is run on it once @body@ finishes, whether
-- it returns normally or throws (the acquire/release pair is run by
-- 'bracket').
withWarden :: (Warden -> IO a) -> IO a
withWarden body = bracket create shutdown body

-- | Allocate a new 'Warden' that owns no threads yet and is open for
-- 'spawn'ing (its registry starts as @Just@ an empty set).
create :: IO Warden
create = do
  registry <- newMVar (Just HashSet.empty)
  return (Warden registry)

-- | Shut down a 'Warden', calling 'Async.cancel' on every thread it still
-- owns. The owned threads are cancelled concurrently.
--
-- After 'shutdown', further calls to 'shutdown' are no-ops, whereas 'spawn',
-- 'spawn_' and 'spawnMask' throw a 'WardenException'.
--
-- Note that any exceptions thrown by the threads will be ignored. If you want
-- exceptions to be propagated, either call 'Async.wait' explicitly on the
-- 'Async', or use 'Async.link'.
shutdown :: Warden -> IO ()
shutdown (Warden v) = do
  -- Mark the 'Warden' as shut down and take the set of remaining threads in
  -- one step. A later 'shutdown' finds 'Nothing' here, so 'mapM_' over the
  -- 'Maybe' does nothing.
  r <- swapMVar v Nothing
  mapM_ (Async.mapConcurrently_ Async.cancel) r

-- | Drop one thread from the set owned by the 'Warden'. Spawned threads run
-- this themselves when their action finishes. If the 'Warden' has already
-- been shut down (the registry is 'Nothing'), it is left unchanged.
forget :: Warden -> Async a -> IO ()
forget (Warden var) child = modifyMVar_ var dropChild
  where
    dropChild Nothing = return Nothing
    dropChild (Just live) =
      -- Force the shrunken set before storing it so no thunk is left
      -- behind in the MVar.
      let remaining = HashSet.delete (void child) live
      in remaining `seq` return (Just remaining)

-- | Spawn a thread with masked exceptions and pass an unmask function to the
-- action.
--
-- The thread is registered with the 'Warden' before 'spawnMask' returns. It
-- removes itself again when the action finishes, normally or by exception,
-- so 'shutdown' only has to cancel threads that are still running.
--
-- Throws 'WardenException' if the 'Warden' has already been shut down.
spawnMask :: Warden -> ((forall b. IO b -> IO b) -> IO a) -> IO (Async a)
spawnMask (Warden v) action = modifyMVarMasked v $ \r -> case r of
  Just asyncs -> do
    -- The fork and the registration run with asynchronous exceptions masked
    -- ('modifyMVarMasked'). With plain 'modifyMVar' this body would run
    -- with the caller's masking state restored. An asynchronous exception
    -- arriving after the fork but before the new set is stored would then
    -- roll the 'MVar' back, leaving a running thread that 'shutdown' never
    -- cancels and whose 'Async' the caller never receives. None of the steps
    -- below wait on another thread, so masking here cannot hold off an
    -- exception indefinitely.
    --
    -- Because we fork while masked, the new thread also starts masked;
    -- 'unmask' is the only way the action lifts that.
    --
    -- 'fixIO' lets the thread refer to its own 'Async', so it can remove
    -- itself from the 'HashSet' when it exits. 'forget' must take 'v', which
    -- we hold until the new set is stored, so the removal can never run
    -- before the insertion.
    this <- fixIO $ \this -> Async.asyncWithUnmask $ \unmask ->
      action unmask `finally` forget (Warden v) this
    -- Force the updated set (as 'forget' does) so repeated spawns do not
    -- build a chain of 'HashSet.insert' thunks inside the 'MVar'.
    let asyncs' = HashSet.insert (void this) asyncs
    asyncs' `seq` return (Just asyncs', this)
  Nothing -> throwIO $ WardenException "Warden has been shut down"

-- | Raised by 'spawn', 'spawn_' and 'spawnMask' when they are called on a
-- 'Warden' that has already been shut down. The 'String' is a human-readable
-- description of the problem.
newtype WardenException
  = WardenException String
  deriving Show

-- | Uses the default 'Exception' methods; 'displayException' falls back to
-- 'show'.
instance Exception WardenException

-- | Start a new thread owned by the 'Warden'. The action runs with
-- asynchronous exceptions unmasked, and the thread is cancelled by 'shutdown'
-- if it is still running at that point.
--
-- Throws 'WardenException' if the 'Warden' has already been shut down.
spawn :: Warden -> IO a -> IO (Async a)
spawn warden action = spawnMask warden (\restore -> restore action)

-- | Like 'spawn', but the resulting 'Async' handle is thrown away. The thread
-- is still owned by the 'Warden', so 'shutdown' cancels it if it is still
-- running.
spawn_ :: Warden -> IO () -> IO ()
spawn_ w action = void (spawn w action)