module Database.PostgreSQL.PQTypes.Transaction
  ( Savepoint (..)
  , withSavepoint
  , begin
  , commit
  , rollback
  , unsafeWithoutTransaction
  ) where

import Control.Monad.Catch
import Data.String
import GHC.Stack

import Data.Monoid.Utils
import Database.PostgreSQL.PQTypes.Class
import Database.PostgreSQL.PQTypes.Internal.Error
import Database.PostgreSQL.PQTypes.SQL.Raw
import Database.PostgreSQL.PQTypes.Transaction.Settings
import Database.PostgreSQL.PQTypes.Utils

-- | Wrapper that represents a savepoint name, as consumed by 'withSavepoint'.
--
-- The 'IsString' instance lets savepoint names be written as string literals
-- (with @OverloadedStrings@). The name is built through the 'IsString'
-- instance of 'RawSQL' and concatenated into the @SAVEPOINT@,
-- @ROLLBACK TO SAVEPOINT@ and @RELEASE SAVEPOINT@ statements.
-- NOTE(review): no quoting or escaping is visible here, so the name should be
-- a trusted SQL identifier.
newtype Savepoint = Savepoint (RawSQL ())

instance IsString Savepoint where
  fromString name = Savepoint (fromString name)

-- | Run an action inside a savepoint. If the action does not complete
-- successfully (it throws or is aborted), roll back to the savepoint. In every
-- case the savepoint is released afterwards.
--
-- This may only be used if a transaction is already active. Note that it
-- provides something like \"nested transaction\".
--
-- See <http://www.postgresql.org/docs/current/static/sql-savepoint.html>
withSavepoint :: (HasCallStack, MonadDB m, MonadMask m) => Savepoint -> m a -> m a
withSavepoint (Savepoint savepoint) action =
  fst <$> generalBracket createSavepoint finish (\() -> action)
  where
    createSavepoint = runQuery_ $ "SAVEPOINT" <+> savepoint

    -- On success the work done since the savepoint is kept. On an exception
    -- or abort it is discarded first. Either way the savepoint is released.
    finish () exitCase = do
      case exitCase of
        ExitCaseSuccess _ -> pure ()
        _ -> runQuery_ $ "ROLLBACK TO SAVEPOINT" <+> savepoint
      runQuery_ $ "RELEASE SAVEPOINT" <+> savepoint

----------------------------------------

-- Note: SQL queries in the functions below that modify transaction state must
-- not be interruptible, so that we don't end up in an unexpected transaction
-- state. However, getConnectionAcquisitionMode should stay interruptible so
-- as not to cause deadlocks if a connection ends up being used from multiple
-- threads.

-- | Begin a transaction using the transaction settings carried by the current
-- connection acquisition mode.
--
-- In 'AcquireOnDemand' mode this throws an 'HPQTypesError' via 'throwDB'.
-- In 'AcquireAndHold' mode it issues a @BEGIN@ statement with the mode's
-- isolation level and permissions, under 'uninterruptibleMask_'.
begin :: (HasCallStack, MonadDB m, MonadMask m) => m ()
begin = do
  -- Read outside the mask so this step stays interruptible.
  mode <- getConnectionAcquisitionMode
  case mode of
    AcquireOnDemand ->
      throwDB $ HPQTypesError "Can't begin a transaction in OnDemand mode"
    AcquireAndHold isolationLevel permissions ->
      uninterruptibleMask_ . runSQL_ $
        smconcat
          [ "BEGIN"
          , isolationClause isolationLevel
          , permissionsClause permissions
          ]
  where
    -- The default settings contribute no clause at all.
    isolationClause DefaultLevel = ""
    isolationClause ReadCommitted = "ISOLATION LEVEL READ COMMITTED"
    isolationClause RepeatableRead = "ISOLATION LEVEL REPEATABLE READ"
    isolationClause Serializable = "ISOLATION LEVEL SERIALIZABLE"

    permissionsClause DefaultPermissions = ""
    permissionsClause ReadOnly = "READ ONLY"
    permissionsClause ReadWrite = "READ WRITE"

-- | Commit the active transaction and immediately begin a new one with the
-- same transaction settings (isolation level and permissions).
--
-- In 'AcquireOnDemand' mode no transaction is held, and this throws an
-- 'HPQTypesError' via 'throwDB'.
commit :: (HasCallStack, MonadDB m, MonadMask m) => m ()
commit =
  -- The acquisition mode is read exactly once, here, while still
  -- interruptible. Calling 'begin' inside the uninterruptible section below
  -- would read it a second time with async exceptions blocked. That can
  -- deadlock if the connection ends up being used from multiple threads.
  getConnectionAcquisitionMode >>= \case
    AcquireOnDemand ->
      throwDB $ HPQTypesError "Can't commit a transaction in OnDemand mode"
    AcquireAndHold isolationLevel permissions ->
      -- COMMIT and the follow-up BEGIN share one uninterruptible section, so
      -- the connection is never left without an open transaction.
      uninterruptibleMask_ $ do
        runSQL_ "COMMIT"
        runSQL_ $ beginStatement isolationLevel permissions
  where
    -- Builds the same BEGIN statement as 'begin'; keep the two in sync.
    beginStatement level perms =
      smconcat ["BEGIN", isolationClause level, permissionsClause perms]

    isolationClause DefaultLevel = ""
    isolationClause ReadCommitted = "ISOLATION LEVEL READ COMMITTED"
    isolationClause RepeatableRead = "ISOLATION LEVEL REPEATABLE READ"
    isolationClause Serializable = "ISOLATION LEVEL SERIALIZABLE"

    permissionsClause DefaultPermissions = ""
    permissionsClause ReadOnly = "READ ONLY"
    permissionsClause ReadWrite = "READ WRITE"

-- | Roll back the active transaction and immediately begin a new one with
-- the same transaction settings (isolation level and permissions).
--
-- In 'AcquireOnDemand' mode no transaction is held, and this throws an
-- 'HPQTypesError' via 'throwDB'.
rollback :: (HasCallStack, MonadDB m, MonadMask m) => m ()
rollback =
  -- The acquisition mode is read exactly once, here, while still
  -- interruptible. Calling 'begin' inside the uninterruptible section below
  -- would read it a second time with async exceptions blocked. That can
  -- deadlock if the connection ends up being used from multiple threads.
  getConnectionAcquisitionMode >>= \case
    AcquireOnDemand ->
      throwDB $ HPQTypesError "Can't rollback a transaction in OnDemand mode"
    AcquireAndHold isolationLevel permissions ->
      -- ROLLBACK and the follow-up BEGIN share one uninterruptible section,
      -- so the connection is never left without an open transaction.
      uninterruptibleMask_ $ do
        runSQL_ "ROLLBACK"
        runSQL_ $ beginStatement isolationLevel permissions
  where
    -- Builds the same BEGIN statement as 'begin'; keep the two in sync.
    beginStatement level perms =
      smconcat ["BEGIN", isolationClause level, permissionsClause perms]

    isolationClause DefaultLevel = ""
    isolationClause ReadCommitted = "ISOLATION LEVEL READ COMMITTED"
    isolationClause RepeatableRead = "ISOLATION LEVEL REPEATABLE READ"
    isolationClause Serializable = "ISOLATION LEVEL SERIALIZABLE"

    permissionsClause DefaultPermissions = ""
    permissionsClause ReadOnly = "READ ONLY"
    permissionsClause ReadWrite = "READ WRITE"

-- | Run a block of code without an open transaction.
--
-- In 'AcquireOnDemand' mode the action is simply run as is. In
-- 'AcquireAndHold' mode the current transaction is committed (uninterruptibly)
-- before the action. A new transaction is begun with 'begin' once the action
-- finishes, whether or not it throws.
--
-- This function is unsafe, because if there is a transaction in progress, it's
-- committed, so the atomicity guarantee is lost.
unsafeWithoutTransaction
  :: (HasCallStack, MonadDB m, MonadMask m)
  => m a
  -> m a
unsafeWithoutTransaction action = do
  mode <- getConnectionAcquisitionMode
  case mode of
    AcquireAndHold {} -> bracket_ commitCurrent begin action
    AcquireOnDemand -> action
  where
    commitCurrent = uninterruptibleMask_ $ runSQL_ "COMMIT"