-- | Monitored threads
--
-- Intended for unqualified import.
module Network.GRPC.Util.Thread (
    ThreadState(..)
    -- * Creating threads
  , ThreadBody
  , newThreadState
  , forkThread
  , threadBody
    -- * Thread debug ID
  , DebugThreadId -- opaque
  , threadDebugId
    -- * Access thread state
  , CancelResult(..)
  , cancelThread
  , ThreadState_(..)
  , getThreadState_
  , unlessAbnormallyTerminated
  , withThreadInterface
  , waitForNormalThreadTermination
  , waitForNormalOrAbnormalThreadTermination
  ) where

import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import Data.Void (Void, absurd)
import Foreign (newStablePtr, freeStablePtr)
import GHC.Stack
import System.IO.Unsafe (unsafePerformIO)

import Network.GRPC.Util.GHC

{-------------------------------------------------------------------------------
  Debug thread IDs
-------------------------------------------------------------------------------}

-- | Debug thread IDs
--
-- Unlike 'ThreadId', these do not correspond to a /running/ thread necessarily,
-- but just enable us to distinguish one thread from another.
data DebugThreadId = DebugThreadId
  { -- | Number identifying the thread (distinct for every ID handed out by
    -- 'newDebugThreadId')
    debugThreadId :: Word
  , -- | Call stack recorded at the point where the ID was allocated
    debugThreadCreatedAt :: CallStack
  }
  deriving stock (Show)

-- | Global counter holding the number for the next 'DebugThreadId'
--
-- This is a top-level mutable variable created through 'unsafePerformIO'. The
-- @NOINLINE@ pragma is required: without it the definition could be inlined at
-- its use sites, creating more than one 'MVar'.
nextDebugThreadId :: MVar Word
{-# NOINLINE nextDebugThreadId #-}
nextDebugThreadId = unsafePerformIO (newMVar 0)

-- | Allocate a fresh 'DebugThreadId'
--
-- Takes the current value of the global counter as the ID, and stores its
-- successor (evaluated eagerly) back into the counter. The call stack at the
-- call site is recorded in the ID.
newDebugThreadId :: HasCallStack => IO DebugThreadId
newDebugThreadId =
    modifyMVar nextDebugThreadId $ \current ->
      let !following = succ current
          -- Pop off the call to 'newDebugThreadId'
          --
          -- We leave the call to 'newThreadState' on the stack because it is
          -- useful to know where that was called /from/.
          createdAt = popCallStack callStack
      in return (following, DebugThreadId current createdAt)

{-------------------------------------------------------------------------------
  State
-------------------------------------------------------------------------------}

-- | State of a thread with public interface of type @a@
data ThreadState a
  = -- | The thread has not yet started
    --
    -- If the thread is cancelled before it is started, then the exception will
    -- be delivered once started. This is important, because it gives the
    -- thread control over /when/ the exception is delivered (that is, when it
    -- chooses to unmask async exceptions).
    --
    -- The alternative would be not to start the thread at all in this case,
    -- but this takes away the control mentioned above; if the thread /needs/
    -- to do something before it can be killed, it must be given that chance.
    -- It may /seem/ that this alternative would give the caller (which
    -- /created/ the thread) more control, but actually that control is
    -- illusory, since the timing of async exceptions is anyway unpredictable.
    ThreadNotStarted DebugThreadId
  | -- | The externally visible thread interface is still being initialized
    ThreadInitializing DebugThreadId ThreadId
  | -- | Thread is ready
    ThreadRunning DebugThreadId ThreadId a
  | -- | Thread terminated normally
    --
    -- This still carries the thread interface: we may need it to query the
    -- thread's final status, for example.
    ThreadDone DebugThreadId a
  | -- | Thread terminated with an exception
    ThreadException DebugThreadId SomeException
  deriving stock (Show, Functor)

-- | Debug ID of the thread; every 'ThreadState' constructor carries one
threadDebugId :: ThreadState a -> DebugThreadId
threadDebugId = \case
    ThreadNotStarted   debugId     -> debugId
    ThreadInitializing debugId _   -> debugId
    ThreadRunning      debugId _ _ -> debugId
    ThreadDone         debugId _   -> debugId
    ThreadException    debugId _   -> debugId

{-------------------------------------------------------------------------------
  Creating threads
-------------------------------------------------------------------------------}

-- | Body of a thread started with 'forkThread'
--
-- The body is handed a function to unmask asynchronous exceptions (the thread
-- is started with exceptions masked), a callback with which it publishes its
-- interface of type @a@ once it is ready, and the thread's debug ID.
type ThreadBody a =
          (forall x. IO x -> IO x) -- ^ Unmask exceptions
       -> (a -> IO ())             -- ^ Mark thread ready
       -> DebugThreadId            -- ^ Unique identifier for this thread
       -> IO ()

-- | Create the state for a new thread
--
-- The state starts out as 'ThreadNotStarted', tagged with a freshly allocated
-- 'DebugThreadId'. No thread is forked here; see 'forkThread'.
newThreadState :: HasCallStack => IO (TVar (ThreadState a))
newThreadState = newDebugThreadId >>= newTVarIO . ThreadNotStarted

-- | Fork a monitored thread
--
-- The new thread runs the given body wrapped in 'threadBody', so that its
-- progress is recorded in the supplied 'ThreadState'. The 'ThreadId' of the
-- new thread is not returned; interact with it through the state instead.
forkThread ::
     HasCallStack
  => ThreadLabel -> TVar (ThreadState a) -> ThreadBody a -> IO ()
forkThread label state body = mask_ $ do
    -- Fork with async exceptions masked ('threadBody' requires this); the
    -- body receives @unmask@ so that it can choose when to unmask.
    _threadId <- forkIOWithUnmask $ \unmask ->
      threadBody label state (body unmask)
    return ()

-- | Wrap the thread body
--
-- This should be wrapped around the body of the thread, and should be called
-- with exceptions masked.
--
-- This is intended for integration with existing libraries (such as @http2@),
-- which might do the forking under the hood.
--
-- What happens depends on the 'ThreadState' on entry:
--
-- * 'ThreadNotStarted': the state is changed to 'ThreadInitializing' and the
--   body is run.
-- * 'ThreadException' (the thread was cancelled before it started): the state
--   is left alone, the recorded exception is thrown to this thread from a
--   helper thread, and the body is still run (see discussion of
--   'ThreadNotStarted').
-- * Any other state is unexpected and results in a call to 'error'.
--
-- The body is passed a callback that moves the state from
-- 'ThreadInitializing' to 'ThreadRunning'. When the body returns or throws,
-- the outcome is recorded in the state; exceptions thrown by the body are
-- caught and recorded rather than rethrown.
threadBody :: forall a.
     HasCallStack
  => ThreadLabel
  -> TVar (ThreadState a)
  -> ((a -> IO ()) -> DebugThreadId -> IO ())
  -> IO ()
threadBody label state body = do
    labelThisThread label
    threadId <- myThreadId

    -- Inspect the initial state and claim the thread in a /single/
    -- transaction. If the read and the write were separate transactions, a
    -- concurrent 'cancelThread' could run in between: it would see
    -- 'ThreadNotStarted', record 'ThreadException' (without throwing anything,
    -- since there is no thread yet), and that state would then be overwritten
    -- by 'ThreadInitializing', silently losing the cancellation.
    initState <- atomically $ do
      st <- readTVar state
      case st of
        ThreadNotStarted debugId ->
          writeTVar state $ ThreadInitializing debugId threadId
        _otherwise ->
          return ()
      return st

    -- See discussion of 'ThreadNotStarted'
    -- It's critical that async exceptions are masked at this point.
    case initState of
      ThreadNotStarted _ ->
        -- State was already updated in the transaction above
        return ()
      ThreadException _ exception ->
        -- We don't change the thread status here: 'cancelThread' offers the
        -- guarantee that the thread status /will/ be in aborted or done state
        -- on return. This means that /externally/ the thread will be
        -- considered done, even if perhaps the thread must still execute some
        -- actions before it can actually terminate.
        --
        -- The exception is thrown from a separate thread: a 'throwTo' aimed at
        -- the calling thread itself is raised immediately, whereas one from
        -- another thread is only delivered once this thread unmasks.
        void . forkIO $ throwTo threadId exception
      _otherwise ->
        unexpected "initState" initState

    let markReady :: a -> STM ()
        markReady a = do
            modifyTVar state $ \oldState ->
              case oldState of
                ThreadInitializing debugId _ ->
                  ThreadRunning debugId threadId a
                ThreadException _ _ ->
                  oldState -- leave alone (see discussion above)
                _otherwise ->
                  unexpected "markReady" oldState

        markDone :: Either SomeException () -> STM ()
        markDone mDone = do
            modifyTVar state $ \oldState ->
              case (oldState, mDone) of
                (ThreadRunning debugId _ iface, Right ()) ->
                  ThreadDone debugId iface
                (ThreadException{}, _) ->
                  oldState -- record /first/ exception
                (_, Left e) ->
                  ThreadException (threadDebugId oldState) e
                _otherwise ->
                  unexpected "markDone" oldState

    res <- try $ body (atomically . markReady) (threadDebugId initState)
    atomically $ markDone res
  where
    -- Report a state that should be impossible at this point. The thread
    -- interface is erased before showing the state, so that no 'Show'
    -- instance for @a@ is required.
    unexpected :: String -> ThreadState a -> x
    unexpected msg st = error $ concat [
          msg
        , ": unexpected "
        , show (() <$ st)
        ]

{-------------------------------------------------------------------------------
  Stopping
-------------------------------------------------------------------------------}

-- | Result of cancelling a thread (see 'cancelThread')
data CancelResult a =
    -- | The thread terminated normally before we could cancel it
    --
    -- Carries the thread interface stored in the final thread state.
    AlreadyTerminated a

    -- | The thread terminated with an exception before we could cancel it
    --
    -- Carries the exception that was already recorded in the thread state.
  | AlreadyAborted SomeException

    -- | We cancelled the thread with the specified exception
    --
    -- The exception was recorded in the thread state, and thrown to the thread
    -- if it had already been started.
  | Cancelled

-- | Kill thread if it is running
--
-- * If the thread is in 'ThreadNotStarted' state, we merely change the state
--   to 'ThreadException'. The thread may still be started (see discussion of
--   'ThreadNotStarted'), but /externally/ the thread will be considered to
--   have terminated.
--
-- * If the thread is initializing or running, we update the state to
--   'ThreadException' and then throw the specified exception to the thread.
--
-- * If the thread is /already/ in 'ThreadException' state, or if the thread is
--   in 'ThreadDone' state, we do nothing.
--
-- In all cases, the caller is guaranteed that the thread state has been updated
-- even if perhaps the thread is still shutting down.
cancelThread :: forall a.
     TVar (ThreadState a)
  -> SomeException
  -> IO (CancelResult a)
cancelThread state e = do
    -- The state is updated atomically first; the exception (if there is a
    -- thread to deliver it to) is thrown only after that transaction commits.
    (result, mTid) <- atomically transition
    forM_ mTid $ \tid -> throwTo tid e
    return result
  where
    -- Decide on the outcome, and which thread (if any) must be interrupted
    transition :: STM (CancelResult a, Maybe ThreadId)
    transition = readTVar state >>= \case
        ThreadNotStarted   debugId            -> abort debugId Nothing
        ThreadInitializing debugId threadId   -> abort debugId (Just threadId)
        ThreadRunning      debugId threadId _ -> abort debugId (Just threadId)
        ThreadException    _ earlier          -> return (AlreadyAborted earlier, Nothing)
        ThreadDone         _ iface            -> return (AlreadyTerminated iface, Nothing)

    -- Record the cancellation in the thread state
    abort ::
         DebugThreadId
      -> Maybe ThreadId
      -> STM (CancelResult a, Maybe ThreadId)
    abort debugId mTid = do
        writeTVar state $ ThreadException debugId e
        return (Cancelled, mTid)

{-------------------------------------------------------------------------------
  Interacting with the thread
-------------------------------------------------------------------------------}

-- | Run a transaction against the thread's interface
--
-- How the interface is obtained depends on the thread state; the transaction
--
-- * blocks (retries) in case of 'ThreadNotStarted' or 'ThreadInitializing'
-- * rethrows the recorded exception in case of 'ThreadException'
-- * is passed the thread interface otherwise
--
-- We do /not/ distinguish between 'ThreadDone' and 'ThreadRunning' here, as
-- doing so is inherently racy (we might return that the client is still
-- running, and then it terminates before the calling code can do anything with
-- that information).
--
-- NOTE: This turns off deadlock detection for the duration of the transaction.
-- It should therefore only be used for transactions that can never be blocked
-- indefinitely.
--
-- Usage note: in practice we use this to interact with threads that in turn
-- interact with @http2@ (running 'sendMessageLoop' or 'recvMessageLoop').
-- Although calls into @http2@ may in fact block indefinitely, we will /catch/
-- those exceptions and treat them as network failures. If a @grapesy@ function
-- ever throws a "blocked indefinitely" exception, this should be reported as a
-- bug in @grapesy@.
withThreadInterface :: forall a b.
     TVar (ThreadState a)
  -> (a -> STM b)
  -> IO b
withThreadInterface state k =
    withoutDeadlockDetection $ atomically $ getThreadInterface >>= k
  where
    getThreadInterface :: STM a
    getThreadInterface = readTVar state >>= \case
        ThreadNotStarted   _         -> retry
        ThreadInitializing _ _       -> retry
        ThreadRunning      _ _ iface -> return iface
        ThreadDone         _   iface -> return iface
        ThreadException    _   err   -> throwSTM err

-- | Wait for the thread to terminate normally
--
-- Blocks (retries) for as long as the thread has not started, is initializing,
-- or is running. If the thread terminated with an exception, this rethrows
-- that exception.
waitForNormalThreadTermination :: TVar (ThreadState a) -> STM ()
waitForNormalThreadTermination state = do
    st <- waitUntilInitialized state
    case st of
      ThreadNotYetRunning_ v -> absurd v
      ThreadRunning_         -> retry
      ThreadDone_            -> return ()
      ThreadException_ e     -> throwSTM e

-- | Wait for the thread to terminate normally or abnormally
--
-- Blocks (retries) until the thread has terminated. Returns 'Nothing' if it
-- terminated normally, and the recorded exception if it terminated with an
-- exception; the exception is returned, not rethrown.
waitForNormalOrAbnormalThreadTermination ::
     TVar (ThreadState a)
  -> STM (Maybe SomeException)
waitForNormalOrAbnormalThreadTermination state = do
    st <- waitUntilInitialized state
    case st of
      ThreadNotYetRunning_ v -> absurd v
      ThreadRunning_         -> retry
      ThreadDone_            -> return Nothing
      ThreadException_ e     -> return (Just e)

-- | Run the specified transaction, unless the thread terminated with an
-- exception
--
-- Blocks (retries) while the thread has not started or is still initializing.
-- Once the thread is running or has terminated normally, the transaction is
-- run and its result returned as 'Right'; if the thread terminated with an
-- exception, that exception is returned as 'Left' and the transaction is not
-- run.
unlessAbnormallyTerminated ::
     TVar (ThreadState a)
  -> STM b
  -> STM (Either SomeException b)
unlessAbnormallyTerminated state f = do
    st <- waitUntilInitialized state
    case st of
      ThreadNotYetRunning_ v -> absurd v
      ThreadRunning_         -> Right <$> f
      ThreadDone_            -> Right <$> f
      ThreadException_ e     -> return (Left e)

-- | Get the abstracted thread state, blocking until the thread is initialized
--
-- Retries for as long as the thread has not started or is still initializing;
-- the 'Void' in the result type records that 'ThreadNotYetRunning_' can
-- therefore never be returned.
waitUntilInitialized ::
     TVar (ThreadState a)
  -> STM (ThreadState_ Void)
waitUntilInitialized = flip getThreadState_ retry

-- | An abstraction of 'ThreadState' without the public interface type.
--
-- The type parameter is the result of the action that 'getThreadState_' runs
-- when the thread is not yet running.
data ThreadState_ notRunning =
      -- | The thread has not started, or is still initializing
      ThreadNotYetRunning_ notRunning
      -- | The thread is running
    | ThreadRunning_
      -- | The thread terminated normally
    | ThreadDone_
      -- | The thread terminated with the given exception
    | ThreadException_ SomeException

-- | Read the thread state, abstracting away the thread interface
--
-- If the thread has not started or is still initializing, the supplied
-- transaction is run and its result wrapped in 'ThreadNotYetRunning_'; in all
-- other states that transaction is not run.
getThreadState_ ::
     TVar (ThreadState a)
  -> STM notRunning
  -> STM (ThreadState_ notRunning)
getThreadState_ state onNotRunning =
    readTVar state >>= \case
      ThreadNotStarted   _     -> notYetRunning
      ThreadInitializing _ _   -> notYetRunning
      ThreadRunning      _ _ _ -> return ThreadRunning_
      ThreadDone         _   _ -> return ThreadDone_
      ThreadException    _   e -> return (ThreadException_ e)
  where
    notYetRunning = ThreadNotYetRunning_ <$> onNotRunning

{-------------------------------------------------------------------------------
  Internal auxiliary
-------------------------------------------------------------------------------}

-- | Locally turn off deadlock detection
--
-- Holds a stable pointer to the 'ThreadId' of the calling thread while the
-- action runs, and frees it again afterwards, also when the action throws.
--
-- See also <https://well-typed.com/blog/2024/01/when-blocked-indefinitely-is-not-indefinite/>.
withoutDeadlockDetection :: IO a -> IO a
withoutDeadlockDetection k =
    myThreadId >>= \self ->
      bracket (newStablePtr self) freeStablePtr (const k)