module Hpgsql.Locking (getMyWeakThreadId, withMutex, Mutex, WeakThreadId) where
import Control.Concurrent (mkWeakThreadId, myThreadId)
import qualified Control.Concurrent.STM as STM
import Control.Exception.Safe (bracket)
import Control.Monad (void)
import Hpgsql.InternalTypes (Mutex (..), WeakThreadId (..), throwIrrecoverableError)
#if MIN_VERSION_base(4,19,0)
import GHC.Conc.Sync (fromThreadId)
#else
import GHC.Conc.Sync (showThreadId)
#endif
-- | Build a 'WeakThreadId' describing the calling thread.
--
-- The result pairs the weak thread reference returned by 'mkWeakThreadId'
-- with an identifier derived from the thread's id. Which identifier is
-- stored depends on the @base@ version this module is compiled against:
-- the result of @fromThreadId@ on @base >= 4.19@, and the result of
-- @showThreadId@ on older versions.
getMyWeakThreadId :: IO WeakThreadId
getMyWeakThreadId = do
  tid <- myThreadId
  wtid <- mkWeakThreadId tid
#if MIN_VERSION_base(4,19,0)
  pure $ WeakThreadId wtid (fromThreadId tid)
#else
  let tidStr = showThreadId tid
  pure $ WeakThreadId wtid tidStr
#endif
-- | Run an action while holding the given 'Mutex'.
--
-- The lock is re-entrant per thread. The mutex state is either 'Nothing'
-- (free) or @Just (owner, grabCount)@:
--
-- * Acquire: if the mutex is free the calling thread becomes the owner with
--   a count of 1; if the calling thread already owns it the count is
--   incremented; otherwise the STM transaction retries until the state
--   changes.
-- * Release: the count is decremented, and the mutex is set back to
--   'Nothing' once the count would drop below 1.
--
-- Acquire and release are paired through 'bracket', so the release step is
-- attempted even when the supplied action throws.
--
-- On release, an irrecoverable error is thrown (via
-- 'throwIrrecoverableError') if the mutex is found free or owned by a
-- different thread, since neither state should be reachable.
withMutex ::
  Mutex ->
  IO a ->
  IO a
withMutex (Mutex blockedByT) f = do
  thisThreadId <- getMyWeakThreadId
  bracket
    -- Acquire: take the lock, or bump our own grab count if we already own it.
    ( STM.atomically $ do
        blockedBy <- STM.readTVar blockedByT
        newSt :: (WeakThreadId, Int) <- case blockedBy of
          Nothing -> pure (thisThreadId, 1)
          Just (tid, nGrabs) ->
            if tid == thisThreadId
              then pure (thisThreadId, nGrabs + 1)
              -- Owned by another thread: block until the TVar changes.
              else STM.retry
        STM.writeTVar blockedByT (Just newSt)
    )
    -- Release: undo exactly one grab made by this thread.
    ( const $ void $ STM.atomically $ do
        blockedBy <- STM.readTVar blockedByT
        newSt <- case blockedBy of
          Nothing -> throwIrrecoverableError "Impossible: should have been blocked but was not!"
          Just (tid, nGrabs) ->
            let newLockState :: Maybe (WeakThreadId, Int)
                newLockState =
                  -- The last outstanding grab frees the mutex entirely.
                  if nGrabs <= 1
                    then Nothing
                    else Just (thisThreadId, nGrabs - 1)
             in if tid == thisThreadId
                  then pure newLockState
                  else throwIrrecoverableError "Impossible: Lock of a different thread!"
        STM.writeTVar blockedByT newSt
    )
    $ \() -> f