{-# LANGUAGE ExistentialQuantification #-}
module Control.Concurrent.Utils
  ( Lock()
  , mkExclusiveLock
  , mkQLock
  , withLock
  ) where

import Control.Monad.Catch (MonadMask)
import qualified Control.Monad.Catch as Catch
import Control.Concurrent.MVar
  ( newMVar
  , takeMVar
  , putMVar
  )
import Control.Concurrent.QSem
import Control.Monad.IO.Class (MonadIO, liftIO)

-- | An opaque lock handle.
--
-- Bundles some hidden lock state with two actions over that state. The
-- first function field is the acquisition action and the second is the
-- release action; 'acquire' and 'release' pattern-match on them in that
-- order. The state's type is existentially hidden, and the module exports
-- only the type (not the constructor), so outside code can build a 'Lock'
-- only through the @mk*@ functions.
data Lock
  = forall state.
    Lock
      state
      (state -> IO ())
      (state -> IO ())

-- | Take the lock by running its acquisition action on its hidden state.
--
-- The IO action is lifted into any 'MonadIO'. For the locks built in this
-- module ('takeMVar' / 'waitQSem') the call may block until the lock
-- becomes available.
acquire :: MonadIO m => Lock -> m ()
acquire lock = case lock of
  Lock st takeIt _ -> liftIO (takeIt st)

-- | Give the lock back by running its release action on its hidden state.
--
-- The IO action is lifted into any 'MonadIO'. It is intended to be paired
-- with a preceding 'acquire'; 'withLock' does this pairing.
release :: MonadIO m => Lock -> m ()
release lock = case lock of
  Lock st _ giveBack -> liftIO (giveBack st)

-- | Create an exclusive lock: at most one thread can hold it at a time.
--
-- Backed by an 'MVar' that starts full. Acquiring empties it with
-- 'takeMVar', and releasing refills it with 'putMVar'.
-- NOTE(review): releasing without a matching acquire blocks, because
-- 'putMVar' waits while the 'MVar' is full.
mkExclusiveLock :: IO Lock
mkExclusiveLock = do
  box <- newMVar ()
  pure (Lock box takeMVar (\v -> putMVar v ()))

-- | Create a counting lock that up to @n@ threads can hold simultaneously.
--
-- Backed by a 'QSem' with initial quantity @n@. Acquiring calls
-- 'waitQSem', and releasing calls 'signalQSem'. 'newQSem' requires a
-- non-negative initial quantity.
-- NOTE(review): with @n == 0@, no caller can acquire the lock until
-- someone releases it first.
mkQLock :: Int -> IO Lock
mkQLock n = do
  sem <- newQSem n
  pure (Lock sem waitQSem signalQSem)

-- | Run an action while holding the given lock.
--
-- The lock is acquired before @body@ runs and released afterwards, even if
-- @body@ throws; 'Catch.bracket_' provides that guarantee. 'MonadMask' is
-- needed for the bracketing, and 'MonadIO' for running the lock's IO actions.
withLock :: (MonadMask m, MonadIO m) => Lock -> m a -> m a
withLock lock body = Catch.bracket_ (acquire lock) (release lock) body