-- | Serial Haskell module compilation
module Language.Haskell.TH.Lock
  ( ensureSerialCompilation
  , ensureSerialCompilationQuietly
  , ensureSerialCompilationVerbose
  ) where

import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar, tryPutMVar, tryReadMVar, MVar)
import Control.Monad (void)
import Language.Haskell.TH.Syntax
import Prelude hiding (log)
import System.IO.Unsafe (unsafePerformIO)

-- | Process-wide lock used by 'ensureSerialCompilation'.
--
-- While a module holds the lock, the 'MVar' contains that module's
-- name; it is empty when no module holds it.
moduleCompilationLock :: MVar String
moduleCompilationLock = unsafePerformIO newEmptyMVar
-- Top-level 'unsafePerformIO' global: NOINLINE stops GHC from inlining
-- the expression at its use sites, which could create several distinct
-- 'MVar's and silently break mutual exclusion.
{-# NOINLINE moduleCompilationLock #-}

-- | Serialise Template Haskell execution across modules compiled by the
-- same GHC process.
--
-- Put a top-level splice of this function right after the import
-- section of every module whose TH code touches mutable compile-time
-- state shared with other modules. The splice:
--
-- * takes the process-wide lock, tagging it with the current module
--   name (from 'loc_module'); if another module holds it, a waiting
--   message is logged and the call blocks until the lock can be taken;
--
-- * registers a module finalizer (via 'addModFinalizer') that releases
--   the lock once the type checker has finished with the module. The
--   finalizer reports a compile error instead of releasing when the
--   lock is empty or tagged with a different module.
--
-- Every lock state change is passed as a one-line message to the
-- supplied logging action. The splice generates no declarations.
--
-- NOTE(review): a second call in the same module finds the lock tagged
-- with its own name and waits for it, but only this module's finalizer
-- can release it, so such a call appears to block forever.
-- NOTE(review): confirm whether GHC runs module finalizers when type
-- checking fails; if it does not, the lock stays held afterwards.
ensureSerialCompilation :: (String -> IO ()) -> Q [Dec]
ensureSerialCompilation logLine = do
  self <- fmap loc_module location
  runIO (acquireLock self)
  addModFinalizer (releaseLock self)
  return []
  where
    lock :: MVar String
    lock = moduleCompilationLock

    -- Prefix a message with the bracketed module name.
    render :: String -> String -> String
    render modName rest = "Module [" <> modName <> rest

    -- Block until this module's name has been stored in the lock.
    acquireLock :: String -> IO ()
    acquireLock self = do
      stored <- tryPutMVar lock self
      if stored
        then logLine (render self "] acquired TH lock")
        else do
          holder <- tryReadMVar lock
          case holder of
            Nothing -> do
              -- The lock was emptied between the two checks; start over.
              logLine (render self "] retries to acquire TH lock")
              acquireLock self
            Just other -> do
              logLine (render self ("] is waiting for [" <> other <> "]"))
              putMVar lock self
              logLine (render self "] acquired TH lock")

    -- Empty the lock, but only when this module is the one holding it.
    releaseLock :: String -> Q ()
    releaseLock self = do
      holder <- runIO (tryReadMVar lock)
      case holder of
        Nothing ->
          reportError (render self "] attempted to release unlocked TH MVar")
        Just other
          | other == self -> runIO $ do
              logLine (render self "] released lock")
              void (takeMVar lock)
          | otherwise ->
              reportError
                ( render self
                    ("] attempted to release lock of TH MVar holding by [" <> other <> "]")
                )

-- | Variant of 'ensureSerialCompilation' that discards every lock status
-- message instead of printing it.
ensureSerialCompilationQuietly :: Q [Dec]
ensureSerialCompilationQuietly = ensureSerialCompilation discard
  where
    discard :: String -> IO ()
    discard _ = return ()

-- | Variant of 'ensureSerialCompilation' that prints each lock status
-- message to standard output, one line per message.
ensureSerialCompilationVerbose :: Q [Dec]
ensureSerialCompilationVerbose = ensureSerialCompilation toStdout
  where
    toStdout :: String -> IO ()
    toStdout = putStrLn