module Database.PostgreSQL.Consumers.Utils
  ( finalize
  , ThrownFrom (..)
  , stopExecution
  , forkP
  , gforkP
  , preparedSqlName
  ) where

import Control.Concurrent.Lifted
import Control.Concurrent.Thread.Group.Lifted qualified as TG
import Control.Concurrent.Thread.Lifted qualified as T
import Control.Exception.Lifted qualified as E
import Control.Monad.Base
import Control.Monad.Catch
import Control.Monad.Trans.Control
import Data.Maybe
import Data.Text qualified as T
import Database.PostgreSQL.PQTypes.Class
import Database.PostgreSQL.PQTypes.SQL.Raw

-- | Run @m@ to obtain a finalizer, then run @action@. The finalizer that @m@
-- produced is executed after @action@ finishes, whether @action@ returns
-- normally or throws.
--
-- The finalizer is handed over through an 'MVar', so the cleanup step (run by
-- 'finally') only executes it if @m@ actually completed. If @m@ throws, the
-- 'MVar' stays empty and the cleanup step does nothing.
--
-- NOTE(review): @m@ and the registration of its result run in the caller's
-- masking state (unmasked in the usual case). An asynchronous exception
-- arriving after @m@ returns but before the finalizer is stored would
-- therefore skip it. Confirm this window is acceptable for callers.
finalize :: (MonadMask m, MonadBase IO m) => m (m ()) -> m a -> m a
finalize m action = do
  registered <- newEmptyMVar
  let runRegistered = do
        mFinalizer <- tryTakeMVar registered
        fromMaybe (pure ()) mFinalizer
      registerThenRun = do
        finalizer <- m
        putMVar registered finalizer
        action
  registerThenRun `finally` runRegistered

----------------------------------------

-- | Exception used to ask a thread to stop.
--
-- Threads started with 'forkP' or 'gforkP' end quietly when they receive
-- 'StopExecution'. Every other exception that reaches such a thread is
-- forwarded to its parent thread, wrapped in 'ThrownFrom'.
--
-- It is registered as an asynchronous exception: 'toException' and
-- 'fromException' go through 'SomeAsyncException'. Handlers that tell
-- synchronous and asynchronous exceptions apart therefore treat it as the
-- latter.
data StopExecution
  = StopExecution
  deriving (Show)

instance Exception StopExecution where
  fromException = E.asyncExceptionFromException
  toException = E.asyncExceptionToException

-- | Wrapper used to report an exception that escaped a thread started with
-- 'forkP' or 'gforkP' back to the thread that spawned it.
--
-- The 'String' is the name given to the child thread. The 'SomeException'
-- is the exception the child terminated with.
data ThrownFrom
  = ThrownFrom String SomeException
  deriving (Show)

-- | Uses the default, synchronous 'Exception' methods.
instance Exception ThrownFrom

-- | Ask a thread to stop by throwing 'StopExecution' to it.
--
-- Threads spawned with 'forkP' or 'gforkP' treat this exception as a request
-- to exit quietly instead of reporting it to their parent.
stopExecution :: MonadBase IO m => ThreadId -> m ()
stopExecution tid = throwTo tid StopExecution

----------------------------------------

-- | Like 'fork', but the child reports failures to the thread that forked it.
--
-- The first argument names the thread. That name is attached (via
-- 'ThrownFrom') to any exception other than 'StopExecution' that the child
-- rethrows to its parent.
forkP :: MonadBaseControl IO m => String -> m () -> m ThreadId
forkP name action = forkImpl fork name action

-- | Like 'forkP', but the new thread is started with 'TG.fork' in the given
-- 'TG.ThreadGroup'.
--
-- The result is whatever 'TG.fork' returns alongside the 'ThreadId'. Per the
-- threads package, that is an action yielding the thread's 'T.Result'.
-- Exceptions other than 'StopExecution' escaping the child are rethrown to
-- the caller's thread, wrapped in 'ThrownFrom'.
gforkP
  :: MonadBaseControl IO m
  => TG.ThreadGroup
  -> String
  -> m ()
  -> m (ThreadId, m (T.Result ()))
gforkP group = forkImpl (TG.fork group)

----------------------------------------

-- | Shared implementation of 'forkP' and 'gforkP'.
--
-- Starts @m@ in a new thread using the supplied forking function @ffork@.
-- The fork happens inside 'E.mask', so the exception handlers below are in
-- place before @m@ starts. @m@ itself runs with the masking state the caller
-- had when 'forkImpl' was entered. Inside the child:
--
-- * 'StopExecution' is swallowed and the thread simply ends;
--
-- * any other exception is wrapped in 'ThrownFrom', tagged with @tname@, and
--   thrown to the thread that called 'forkImpl'.
forkImpl
  :: MonadBaseControl IO m
  => (m () -> m a)
  -> String
  -> m ()
  -> m a
forkImpl ffork tname m = E.mask $ \restore -> do
  parentTid <- myThreadId
  let childBody = restore m `E.catches` handlers
      handlers =
        [ E.Handler $ \StopExecution -> pure ()
        , E.Handler $ \e -> throwTo parentTid (ThrownFrom tname e)
        ]
  ffork childBody

-- | Name for a prepared statement: @baseName@, a @$@, and the text of
-- @tableName@, cut down to its first 63 characters.
--
-- NOTE(review): 63 presumably mirrors PostgreSQL's identifier limit
-- (NAMEDATALEN - 1), which counts bytes, whereas 'T.take' counts characters.
-- Separately, two inputs that share their first 63 characters yield the same
-- name. Confirm neither case can occur for the tables in use.
preparedSqlName :: T.Text -> RawSQL () -> QueryName
preparedSqlName baseName tableName = QueryName (T.take maxNameLength fullName)
  where
    maxNameLength = 63
    fullName = T.concat [baseName, "$", unRawSQL tableName]