module Hpgsql.Locking (getMyWeakThreadId, withMutex, Mutex, WeakThreadId) where

import Control.Concurrent (mkWeakThreadId, myThreadId)
import qualified Control.Concurrent.STM as STM
import Control.Exception.Safe (bracket)
import Control.Monad (void)
import Hpgsql.InternalTypes (Mutex (..), WeakThreadId (..), throwIrrecoverableError)
#if MIN_VERSION_base(4,19,0)
import GHC.Conc.Sync (fromThreadId)
#else
import GHC.Conc.Sync (showThreadId)
#endif

-- | Build a 'WeakThreadId' for the calling thread.
--
-- The result pairs a weak pointer to the current thread's @ThreadId@
-- (obtained with 'mkWeakThreadId') with an identifier derived from that same
-- @ThreadId@: @fromThreadId@ on base >= 4.19, @showThreadId@ on older versions.
getMyWeakThreadId :: IO WeakThreadId
getMyWeakThreadId = do
  -- We deliberately don't keep a strong reference to the `ThreadId`: holding
  -- one can stop the thread from receiving runtime exceptions and can prevent
  -- a dead thread from being garbage-collected. See the documentation of
  -- `mkWeakThreadId` in Control.Concurrent.
  tid <- myThreadId
  wtid <- mkWeakThreadId tid
#if MIN_VERSION_base(4,19,0)
  pure $ WeakThreadId wtid (fromThreadId tid)
#else
  let tidStr = showThreadId tid
  pure $ WeakThreadId wtid tidStr
#endif

-- | Run an action while holding the given 'Mutex'.
--
-- The mutex is re-entrant per thread: its state is either @Nothing@ (free) or
-- @Just (owner, grabCount)@. If the calling thread already owns it, the grab
-- count is incremented and the action runs immediately. If another thread owns
-- it, acquisition blocks (via 'STM.retry') until the state changes.
--
-- The lock is released through 'bracket', i.e. both when the action returns
-- and when it throws. Releasing decrements the grab count and frees the mutex
-- once the count reaches zero.
--
-- An irrecoverable error is thrown (via 'throwIrrecoverableError') if, at
-- release time, the mutex is found free or owned by a different thread.
withMutex ::
  Mutex ->
  IO a ->
  IO a
withMutex (Mutex blockedByT) f = do
  thisThreadId <- getMyWeakThreadId
  bracket
    ( STM.atomically $ do
        blockedBy <- STM.readTVar blockedByT
        newSt <- case blockedBy of
          -- Free: take it with a grab count of one.
          Nothing -> pure (thisThreadId, 1)
          Just (tid, nGrabs) ->
            if tid == thisThreadId
              -- Already ours: nested acquisition, bump the count.
              then pure (thisThreadId, nGrabs + 1)
              -- Held by someone else: wait until the TVar changes.
              else STM.retry
        STM.writeTVar blockedByT (Just newSt)
    )
    -- Release lock on success or error
    ( const $ void $ STM.atomically $ do
        blockedBy <- STM.readTVar blockedByT
        newSt <- case blockedBy of
          Nothing -> throwIrrecoverableError "Impossible: should have been blocked but was not!"
          Just (tid, nGrabs) ->
            let -- Last grab frees the mutex; otherwise just decrement.
                newLockState =
                  if nGrabs <= 1
                    then Nothing
                    else Just (thisThreadId, nGrabs - 1)
             in if tid == thisThreadId
                  then pure newLockState
                  else throwIrrecoverableError "Impossible: Lock of a different thread!"
        STM.writeTVar blockedByT newSt
    )
    $ \() -> f