-- | Definition of internal DBT state.
module Database.PostgreSQL.PQTypes.Internal.State
  ( -- * ConnectionData
    ConnectionData
  , getConnectionSource
  , getConnectionAcquisitionModeIO
  , withConnectionData
  , changeAcquisitionModeTo
  , withConnection

    -- * DBState
  , DBState (..)
  , mkDBState
  , updateStateWith
  ) where

import Control.Concurrent.MVar.Lifted
import Control.Monad
import Control.Monad.Base
import Control.Monad.Catch
import Data.Function
import Data.Typeable
import Foreign.ForeignPtr
import GHC.Stack

import Data.Monoid.Utils
import Database.PostgreSQL.PQTypes.FromRow
import Database.PostgreSQL.PQTypes.Internal.BackendPid
import Database.PostgreSQL.PQTypes.Internal.C.Types
import Database.PostgreSQL.PQTypes.Internal.Connection
import Database.PostgreSQL.PQTypes.Internal.Exception
import Database.PostgreSQL.PQTypes.Internal.QueryResult
import Database.PostgreSQL.PQTypes.SQL
import Database.PostgreSQL.PQTypes.SQL.Class
import Database.PostgreSQL.PQTypes.Transaction.Settings

-- | State of the connection tracked by 'ConnectionData'. The @cdata@ parameter
-- is the extra value that the connection source hands out together with the
-- 'Connection' and expects back when the connection is returned.
data ConnectionState cdata
  = OnDemand
  -- ^ No connection is held; 'withConnection' takes one from the source for the
  -- duration of each action.
  | Acquired !IsolationLevel !Permissions !Connection !cdata
  -- ^ A connection is held, along with the isolation level and permissions that
  -- were requested when it was acquired.
  | Finalized
  -- ^ The state was finalized by 'finalizeConnectionData'; any further use is
  -- reported with @error "finalized connection"@.

-- Note: initConnection{State,Data} and finalizeConnection{State,Data} need to
-- be invoked inside bracket and run with asynchronous exceptions softly
-- masked. In addition, they may run queries that start/finish a
-- transaction. Running queries is a blocking (and thus interruptible)
-- operation, but if these queries are interrupted with an asynchronous
-- exception, then a connection is leaked, so they need to be run with
-- asynchronous exceptions hard masked with uninterruptibleMask.
--
-- What is more, these queries themselves can throw an exception if e.g. network
-- goes bye-bye and PostgreSQL can't be reached, therefore they need exception
-- handlers themselves so connections don't leak.

-- | Build the initial 'ConnectionState' for the given acquisition mode.
--
-- * 'AcquireOnDemand' yields 'OnDemand' without touching the connection source.
--
-- * 'AcquireAndHold' takes a connection from the source and runs a @BEGIN@
--   statement carrying the requested isolation level and permissions on it.
--
-- If the @BEGIN@ query throws, the connection is handed back to the source with
-- 'ExitCaseException' and the exception is rethrown, so the connection doesn't
-- leak.
initConnectionState
  :: (MonadBase IO m, MonadMask m)
  => InternalConnectionSource m cdata
  -> ConnectionAcquisitionMode
  -> m (ConnectionState cdata)
initConnectionState ics = \case
  AcquireOnDemand -> pure OnDemand
  AcquireAndHold tsIsolationLevel tsPermissions -> do
    let initSql =
          smconcat
            [ "BEGIN"
            , case tsIsolationLevel of
                DefaultLevel -> ""
                ReadCommitted -> "ISOLATION LEVEL READ COMMITTED"
                RepeatableRead -> "ISOLATION LEVEL REPEATABLE READ"
                Serializable -> "ISOLATION LEVEL SERIALIZABLE"
            , case tsPermissions of
                DefaultPermissions -> ""
                ReadOnly -> "READ ONLY"
                ReadWrite -> "READ WRITE"
            ]
    (conn, cdata) <- takeConnection ics
    -- Running the query is interruptible, hence the hard mask: an asynchronous
    -- exception arriving here would leak the connection.
    _ <- uninterruptibleMask_ $ do
      liftBase (runQueryIO @SQL conn initSql) `catch` \e -> do
        putConnection ics (conn, cdata) (ExitCaseException e)
        throwM e
    pure $ Acquired tsIsolationLevel tsPermissions conn cdata

-- | Finalize a 'ConnectionState' according to how the guarded action exited.
--
-- * 'OnDemand': nothing is held, so nothing is done.
--
-- * 'Acquired': runs @COMMIT@ if the exit case is 'ExitCaseSuccess' and
--   @ROLLBACK@ otherwise, then returns the connection to the source together
--   with the exit case. If the query itself throws, the connection is returned
--   with 'ExitCaseException' and the exception is rethrown.
--
-- * 'Finalized': calls 'error', as the state must not be finalized twice.
finalizeConnectionState
  :: (HasCallStack, MonadBase IO m, MonadMask m)
  => InternalConnectionSource m cdata
  -> ExitCase r
  -> ConnectionState cdata
  -> m ()
finalizeConnectionState ics ec = \case
  OnDemand -> pure ()
  Acquired _ _ conn cdata -> do
    let finalizeSql = case ec of
          ExitCaseSuccess _ -> "COMMIT"
          _ -> "ROLLBACK"
    -- Running the query is interruptible, hence the hard mask: an asynchronous
    -- exception arriving here would leak the connection.
    _ <- uninterruptibleMask_ $ do
      liftBase (runQueryIO @SQL conn finalizeSql) `catch` \e -> do
        putConnection ics (conn, cdata) (ExitCaseException e)
        throwM e
    putConnection ics (conn, cdata) ec
  Finalized -> error "finalized connection"

----------------------------------------

-- | A connection source paired with the mutable state of the connection taken
-- from it. The type of the source-specific @cdata@ is hidden existentially.
data ConnectionData m = forall cdata. ConnectionData
  { cdConnectionSource :: !(InternalConnectionSource m cdata)
  -- ^ Source that connections are taken from and returned to.
  , cdConnectionState :: !(MVar (ConnectionState cdata))
  -- ^ Current connection state; it is taken out of the 'MVar' while an
  -- operation works with it.
  }

-- | Repackage the internal connection source stored in the 'ConnectionData' as
-- a 'ConnectionSourceM'.
getConnectionSource :: ConnectionData m -> ConnectionSourceM m
getConnectionSource ConnectionData {..} = ConnectionSourceM cdConnectionSource

-- | Read the current connection state (without taking it out of the 'MVar')
-- and report the acquisition mode it corresponds to: 'AcquireOnDemand' for
-- 'OnDemand', and 'AcquireAndHold' with the stored isolation level and
-- permissions for 'Acquired'.
--
-- Calls 'error' if the state is 'Finalized'.
getConnectionAcquisitionModeIO
  :: HasCallStack
  => ConnectionData m
  -> IO ConnectionAcquisitionMode
getConnectionAcquisitionModeIO ConnectionData {..} = do
  readMVar cdConnectionState >>= \case
    OnDemand -> pure AcquireOnDemand
    Acquired isolationLevel permissions _ _ -> do
      pure $ AcquireAndHold isolationLevel permissions
    Finalized -> error "finalized connection"

-- | Run an action with freshly initialized 'ConnectionData', finalizing it
-- afterwards with the action's exit case.
--
-- If the transaction settings carry a restart predicate, an exception accepted
-- by it (see @expred@ below) causes the whole initialize\/run\/finalize cycle
-- to be started again with the attempt counter increased by one. The counter
-- starts at 1.
withConnectionData
  :: (HasCallStack, MonadBase IO m, MonadMask m)
  => ConnectionSourceM m
  -> TransactionSettings
  -> (ConnectionData m -> m r)
  -> m r
withConnectionData cs ts action = (`fix` 1) $ \loop n -> do
  let maybeRestart = case tsRestartPredicate ts of
        Just _ -> handleJust (expred n) $ \_ -> loop $ n + 1
        Nothing -> id
  maybeRestart
    . fmap fst
    . generalBracket (initConnectionData cs cam) finalizeConnectionData
    $ action
  where
    cam = tsConnectionAcquisitionMode ts

    -- Decide whether the exception raised during attempt number @n@ warrants a
    -- restart: 'Just' means restart, 'Nothing' lets the exception propagate.
    expred :: Integer -> SomeException -> Maybe ()
    expred n e = do
      -- check if the predicate exists
      RestartPredicate f <- tsRestartPredicate ts
      -- cast exception to the type expected by the predicate
      err <-
        msum
          [ -- either cast the exception itself...
            fromException e
          , -- ...or extract it from DBException
            fromException e >>= \DBException {..} -> cast dbeError
          ]
      -- check if the predicate allows for the restart
      guard $ f err n

-- | Switch the connection state to match the requested acquisition mode.
--
-- * 'OnDemand' to 'AcquireOnDemand': no change.
--
-- * 'OnDemand' to 'AcquireAndHold': a connection is acquired with
--   'initConnectionState'.
--
-- * 'Acquired' to 'AcquireOnDemand': the held connection is finalized as a
--   success (see 'finalizeConnectionState') and the state becomes 'OnDemand'.
--
-- * 'Acquired' to 'AcquireAndHold': allowed only if the isolation level and
--   permissions are the same as the current ones; a mismatch is reported with
--   'error'.
--
-- Calls 'error' if the state is 'Finalized'.
--
-- The state is taken out of the 'MVar' for the duration of the operation. Each
-- branch puts the resulting state back itself; 'bracketOnError' restores the
-- original state only when the branch throws.
changeAcquisitionModeTo
  :: (HasCallStack, MonadBase IO m, MonadMask m)
  => ConnectionAcquisitionMode
  -> ConnectionData m
  -> m ()
changeAcquisitionModeTo cam ConnectionData {..} = do
  bracketOnError (takeMVar cdConnectionState) (putMVar cdConnectionState) $ \case
    OnDemand -> case cam of
      AcquireOnDemand -> putMVar cdConnectionState OnDemand
      _ -> mask_ $ do
        -- Need to mask, if asynchronous exception arrives between
        -- initConnectionState and putMVar, the connection leaks.
        newConnState <- initConnectionState cdConnectionSource cam
        putMVar cdConnectionState newConnState
    connState@(Acquired isolationLevel permissions _ _) -> case cam of
      AcquireOnDemand -> mask_ $ do
        -- Need to mask, if asynchronous exception arrives between
        -- finalizeConnectionState and putMVar, we end up with an invalid
        -- (finalized) connection state.
        finalizeConnectionState cdConnectionSource (ExitCaseSuccess ()) connState
        putMVar cdConnectionState OnDemand
      AcquireAndHold newIsolationLevel newPermissions -> do
        when (isolationLevel /= newIsolationLevel) $ do
          error $
            "isolation level mismatch (current: "
              ++ show isolationLevel
              ++ ", new: "
              ++ show newIsolationLevel
              ++ ")"
        when (permissions /= newPermissions) $ do
          error $
            "permissions mismatch (current: "
              ++ show permissions
              ++ ", new: "
              ++ show newPermissions
              ++ ")"
        putMVar cdConnectionState connState
    Finalized -> error "finalized connection"

-- | Run an action with a 'Connection'.
--
-- The connection state is taken out of the 'MVar' for the duration of the
-- action and put back afterwards.
--
-- * 'OnDemand': a connection is taken from the source, the action is wrapped
--   in @BEGIN READ ONLY@ \/ @ROLLBACK@, and the connection is returned to the
--   source together with the exit case of the action.
--
-- * 'Acquired': the action is run on the held connection as is.
--
-- * 'Finalized': calls 'error'.
withConnection
  :: (HasCallStack, MonadBase IO m, MonadMask m)
  => ConnectionData m
  -> (Connection -> m r)
  -> m r
withConnection ConnectionData {..} action = do
  bracket (takeMVar cdConnectionState) (putMVar cdConnectionState) $ \case
    OnDemand ->
      fst
        <$> generalBracket
          (takeConnection cdConnectionSource)
          (putConnection cdConnectionSource)
          ( \(conn, _cdata) ->
              -- The queries are run with asynchronous exceptions hard masked so
              -- that they can't be interrupted halfway through.
              bracket_
                (liftBase . uninterruptibleMask_ $ runQueryIO @SQL conn "BEGIN READ ONLY")
                (liftBase . uninterruptibleMask_ $ runQueryIO @SQL conn "ROLLBACK")
                (action conn)
          )
    Acquired _ _ conn _ -> action conn
    Finalized -> error "finalized connection"

-- | Create 'ConnectionData' for a connection source: compute the initial state
-- with 'initConnectionState' and store it in a fresh 'MVar'.
initConnectionData
  :: (MonadBase IO m, MonadMask m)
  => ConnectionSourceM m
  -> ConnectionAcquisitionMode
  -> m (ConnectionData m)
initConnectionData (ConnectionSourceM ics) cam = do
  connState <- newMVar =<< initConnectionState ics cam
  pure $
    ConnectionData
      { cdConnectionSource = ics
      , cdConnectionState = connState
      }

-- | Finalize 'ConnectionData': take the current state out of the 'MVar' and
-- pass it to 'finalizeConnectionState' along with the exit case.
--
-- Whether or not that succeeds, 'Finalized' is put into the 'MVar' afterwards,
-- so later attempts to use the 'ConnectionData' are reported with 'error'.
finalizeConnectionData
  :: (HasCallStack, MonadBase IO m, MonadMask m)
  => ConnectionData m
  -> ExitCase r
  -> m ()
finalizeConnectionData ConnectionData {..} ec = do
  (`finally` putMVar cdConnectionState Finalized) $ do
    connState <- takeMVar cdConnectionState
    finalizeConnectionState cdConnectionSource ec connState

----------------------------------------

-- | Internal DB state.
data DBState m = DBState
  { dbConnectionData :: !(ConnectionData m)
  -- ^ Active connection.
  , dbConnectionStats :: !ConnectionStats
  -- ^ Statistics associated with the session.
  , dbRestartPredicate :: !(Maybe RestartPredicate)
  -- ^ Restart predicate from initial 'TransactionSettings'.
  , dbLastQuery :: !(BackendPid, SomeSQL)
  -- ^ Last SQL query that was executed along with ID of the server process
  -- attached to the session that executed it.
  , dbRecordLastQuery :: !Bool
  -- ^ Whether running query should override 'dbLastQuery'.
  , dbQueryResult :: !(forall row. FromRow row => Maybe (QueryResult row))
  -- ^ Current query result.
  }

-- | Construct the initial 'DBState' for the given connection data.
--
-- Statistics start at 'initialConnectionStats', the restart predicate is
-- copied from the transaction settings, the last query is an empty 'SQL' paired
-- with 'noBackendPid', recording of the last query is enabled and there is no
-- query result.
mkDBState
  :: ConnectionData m
  -> TransactionSettings
  -> DBState m
mkDBState cd ts =
  DBState
    { dbConnectionData = cd
    , dbConnectionStats = initialConnectionStats
    , dbRestartPredicate = tsRestartPredicate ts
    , dbLastQuery = (noBackendPid, SomeSQL (mempty :: SQL))
    , dbRecordLastQuery = True
    , dbQueryResult = Nothing
    }

-- | Fold the outcome of a query into the 'DBState'.
--
-- The statistics are passed through the supplied update function, the query
-- result is replaced with one built from the given result pointer, and the
-- last query is overwritten with the given query and the backend PID of the
-- connection, but only if 'dbRecordLastQuery' is set.
--
-- The first component of the triple is returned unchanged next to the new
-- state.
updateStateWith
  :: IsSQL sql
  => Connection
  -> DBState m
  -> sql
  -> (r, ForeignPtr PGresult, ConnectionStats -> ConnectionStats)
  -> IO (r, DBState m)
updateStateWith conn st sql (r, res, updateStats) = do
  pure
    ( r
    , st
        { dbConnectionStats = updateStats $ dbConnectionStats st
        , dbLastQuery =
            if dbRecordLastQuery st
              then (connBackendPid conn, SomeSQL sql)
              else dbLastQuery st
        , dbQueryResult = Just $ mkQueryResult sql (connBackendPid conn) res
        }
    )