module Database.PostgreSQL.Consumers.Utils
  ( finalize
  , ThrownFrom (..)
  , stopExecution
  , forkP
  , gforkP
  , preparedSqlName
  ) where

import Control.Concurrent.Lifted
import Control.Concurrent.Thread.Group.Lifted qualified as TG
import Control.Concurrent.Thread.Lifted qualified as T
import Control.Exception.Lifted qualified as E
import Control.Monad.Base
import Control.Monad.Catch
import Control.Monad.Trans.Control
import Data.Maybe
import Data.Text qualified as T
import Database.PostgreSQL.PQTypes.Class
import Database.PostgreSQL.PQTypes.SQL.Raw

-- | Run an action 'm' that returns a finalizer and perform the returned
-- finalizer after the action 'action' completes.
--
-- The finalizer runs whether 'action' returns normally or throws. If 'm'
-- itself throws, no finalizer has been produced, so none is run.
--
-- Both 'm' and 'action' run with the caller's masking state restored; only
-- the hand-off of the finalizer into the 'MVar' happens masked. This closes
-- the window in which an asynchronous exception arriving right after 'm'
-- returns could cause the finalizer to be lost.
finalize :: (MonadMask m, MonadBase IO m) => m (m ()) -> m a -> m a
finalize m action = mask $ \restore -> do
  finalizer <- newEmptyMVar
  flip finally (tryTakeMVar finalizer >>= fromMaybe (pure ())) $ do
    -- 'restore m' returns into the masked state, so no asynchronous
    -- exception can be delivered between obtaining the finalizer and
    -- storing it ('putMVar' on a fresh, empty 'MVar' does not block).
    putMVar finalizer =<< restore m
    restore action

----------------------------------------

-- | Exception thrown to a thread to stop its execution.
--
-- It is wrapped as an asynchronous exception (via
-- 'E.asyncExceptionToException'), so it can be distinguished from
-- synchronous exceptions by handlers that care about the difference.
--
-- All exceptions other than 'StopExecution' thrown to threads spawned by
-- 'forkP' and 'gforkP' are propagated back to the parent thread.
data StopExecution = StopExecution
  deriving (Show)

instance Exception StopExecution where
  toException = E.asyncExceptionToException
  fromException = E.asyncExceptionFromException

-- | Exception thrown from a child thread.
--
-- Threads started with 'forkP' or 'gforkP' throw this to their parent when
-- any exception other than 'StopExecution' escapes them.
data ThrownFrom
  = ThrownFrom
      String
      -- ^ Name given to the child thread when it was forked.
      SomeException
      -- ^ The exception that escaped the child thread.
  deriving (Show)

instance Exception ThrownFrom

-- | Stop execution of a thread by throwing 'StopExecution' to it.
--
-- Threads started with 'forkP' or 'gforkP' catch 'StopExecution' and exit
-- without notifying their parent.
stopExecution :: MonadBase IO m => ThreadId -> m ()
stopExecution tid = throwTo tid StopExecution

----------------------------------------

-- | Modified version of 'fork' that propagates thrown exceptions to the parent
-- thread.
--
-- The 'String' names the child thread. If the child is stopped with
-- 'StopExecution' it exits quietly; any other exception escaping it is
-- rethrown in the parent wrapped in 'ThrownFrom' together with that name.
forkP :: MonadBaseControl IO m => String -> m () -> m ThreadId
forkP = forkImpl fork

-- | Modified version of 'TG.fork' that propagates thrown exceptions to the
-- parent thread.
--
-- The child is registered in the given thread group. The 'String' names the
-- child; exception propagation follows the same rules as 'forkP'. The
-- second component of the result is the wait action returned by 'TG.fork'.
gforkP
  :: MonadBaseControl IO m
  => TG.ThreadGroup
  -> String
  -> m ()
  -> m (ThreadId, m (T.Result ()))
gforkP tg = forkImpl (TG.fork tg)

----------------------------------------

-- | Shared implementation of 'forkP' and 'gforkP'.
--
-- Forks @m@ using @ffork@. Inside the child, @m@ runs with the caller's
-- masking state restored.
--
-- * If @m@ is interrupted by 'StopExecution', the child finishes silently.
-- * Any other exception is thrown to the parent thread as
--   @'ThrownFrom' tname e@.
forkImpl
  :: MonadBaseControl IO m
  => (m () -> m a)
  -- ^ Function that actually forks the thread (e.g. 'fork' or 'TG.fork').
  -> String
  -- ^ Name of the child thread, reported in 'ThrownFrom'.
  -> m ()
  -- ^ Action to run in the child thread.
  -> m a
forkImpl ffork tname m = E.mask $ \release -> do
  parent <- myThreadId
  -- Forking while masked means the child (assuming @ffork@ inherits the
  -- masking state, as 'fork' does) cannot receive an asynchronous exception
  -- before the handlers below are installed; only @m@ itself is unmasked.
  ffork $
    release m
      `E.catches` [ E.Handler $ \StopExecution -> pure ()
                  , E.Handler $ throwTo parent . ThrownFrom tname
                  ]

-- | Build the name of a prepared statement as @baseName <> "$" <> table@,
-- truncated to at most 'maxNameLength' characters.
--
-- NOTE(review): 63 presumably mirrors PostgreSQL's default identifier length
-- limit (NAMEDATALEN - 1), which is measured in bytes. 'T.take' counts
-- characters, so names containing non-ASCII text could still exceed that
-- limit — confirm whether this matters for the names used here.
preparedSqlName :: T.Text -> RawSQL () -> QueryName
preparedSqlName baseName tableName =
  QueryName . T.take maxNameLength $ baseName <> "$" <> unRawSQL tableName
  where
    maxNameLength :: Int
    maxNameLength = 63