module Hpgsql.Transaction (withTransaction, withTransactionMode, begin, beginMode, commit, rollback, transactionStatus, IsolationLevel (..), ReadWriteMode (..), TransactionStatus (..)) where

import qualified Control.Concurrent.STM as STM
import Control.Exception.Safe (Exception (..), bracketWithError, throw, tryAny)
import Control.Monad (unless)
import Hpgsql.Internal (execute_, fullTransactionStatus, transactionStatus)
import Hpgsql.InternalTypes (HPgConnection (..), InternalConnectionState (..), IrrecoverableHpgsqlError)
import Hpgsql.Query (Query)
import Hpgsql.TransactionStatusInternal (TransactionStatus (..))

-- The types and constructors here share their names with postgresql-simple's
-- where I thought that sameness would be convenient. The implementation is our
-- own, and I deliberately kept some things different to suit my preferences.

-- | Isolation level requested when opening a transaction.
--
-- 'DefaultIsolationLevel' emits no @ISOLATION LEVEL@ clause at all, leaving
-- the choice to the server; every other constructor maps to the
-- corresponding @ISOLATION LEVEL ...@ clause of @BEGIN@.
data IsolationLevel
  = DefaultIsolationLevel
  | ReadUncommitted
  | ReadCommitted
  | RepeatableRead
  | Serializable
  deriving (Show, Eq, Ord, Enum, Bounded)

-- | Read/write access mode requested when opening a transaction.
--
-- 'DefaultReadWriteMode' emits no access-mode clause, leaving the choice to
-- the server; 'ReadWrite' and 'ReadOnly' map to @READ WRITE@ and
-- @READ ONLY@ respectively.
data ReadWriteMode
  = DefaultReadWriteMode
  | ReadWrite
  | ReadOnly
  deriving (Show, Eq, Ord, Enum, Bounded)

-- | Opens a transaction with the server's default isolation level and
-- read/write mode (a bare @BEGIN@).
begin :: HPgConnection -> IO ()
begin conn = beginMode conn DefaultIsolationLevel DefaultReadWriteMode

-- | Opens a transaction with the given isolation level and read/write mode.
--
-- Sends @BEGIN@ followed by the access-mode clause (if any) and the
-- isolation-level clause (if any), separated by a comma. A parameter left at
-- 'DefaultReadWriteMode' / 'DefaultIsolationLevel' contributes no clause.
beginMode :: HPgConnection -> IsolationLevel -> ReadWriteMode -> IO ()
beginMode conn il rw = do
  let readWrite :: Maybe Query
      readWrite = case rw of
        ReadWrite -> Just "READ WRITE"
        ReadOnly -> Just "READ ONLY"
        DefaultReadWriteMode -> Nothing
      isolLvl :: Maybe Query
      isolLvl = case il of
        DefaultIsolationLevel -> Nothing
        Serializable -> Just "ISOLATION LEVEL SERIALIZABLE"
        RepeatableRead -> Just "ISOLATION LEVEL REPEATABLE READ"
        ReadCommitted -> Just "ISOLATION LEVEL READ COMMITTED"
        ReadUncommitted -> Just "ISOLATION LEVEL READ UNCOMMITTED"
  execute_ conn $ "BEGIN " <> withComma readWrite isolLvl
  where
    -- Joins the present clauses with "," (BEGIN accepts a comma-separated
    -- list of transaction modes); yields "" when neither is present.
    withComma :: Maybe Query -> Maybe Query -> Query
    withComma mv1 mv2 = case (mv1, mv2) of
      (Just v1, Just v2) -> v1 <> "," <> v2
      (Just v1, Nothing) -> v1
      (Nothing, Just v2) -> v2
      (Nothing, Nothing) -> ""

-- | Sends @COMMIT@ on the connection.
commit :: HPgConnection -> IO ()
commit conn = execute_ conn "COMMIT"

-- | Sends @ROLLBACK@ on the connection.
rollback :: HPgConnection -> IO ()
rollback conn = execute_ conn "ROLLBACK"

-- | Runs the supplied action inside a transaction with the database's
-- default isolation level and read/write mode: it @BEGIN@s a transaction,
-- runs the action and then @COMMIT@s if no exception was thrown.
-- If the action throws a synchronous exception, this runs @ROLLBACK@ and
-- rethrows. If an asynchronous exception is thrown, Hpgsql ensures a
-- @ROLLBACK@ will be issued before your next query on the same connection.
withTransaction :: HPgConnection -> IO a -> IO a
withTransaction conn = withTransactionMode conn DefaultIsolationLevel DefaultReadWriteMode

-- | Runs the supplied action inside a transaction with the supplied
-- isolation level and read/write mode: it @BEGIN@s a transaction, runs the
-- action and then @COMMIT@s if no exception was thrown.
-- If the action throws a synchronous exception, this runs @ROLLBACK@ and
-- rethrows. If an asynchronous exception is thrown, Hpgsql ensures a
-- @ROLLBACK@ will be issued before your next query on this connection.
-- An 'IrrecoverableHpgsqlError' is rethrown as-is, without any rollback.
withTransactionMode :: HPgConnection -> IsolationLevel -> ReadWriteMode -> IO a -> IO a
withTransactionMode conn il rw f = bracketWithError (beginMode conn il rw) cleanup $ \() -> do
  res <- tryAny f
  case res of
    Left ex -> case fromException ex of
      Just (_ :: IrrecoverableHpgsqlError) -> throw ex -- Rethrow internal errors
      Nothing -> do
        -- In case of a synchronous exception, we rollback synchronously.
        -- If this is interrupted:
        -- - Before ROLLBACK is sent, `cleanup` will enqueue a "ROLLBACK".
        -- - After ROLLBACK is sent but before it finishes, a ROLLBACK will still be enqueued
        --   (check `cleanup`), which means:
        --      - If ROLLBACK has already completed by the time the new ROLLBACK is meant to be sent,
        --        the new ROLLBACK will produce a "WARNING: there is no transaction in progress". This
        --        isn't great, but is mostly harmless.
        --      - If ROLLBACK is cancelled by the future ROLLBACK, all is good as well.
        -- - After ROLLBACK is sent and completed, `cleanup` won't enqueue a new ROLLBACK, and all is well.
        --
        -- We rollback here, not in `cleanup`, because that runs with async exceptions masked
        rollback conn >> throw ex -- Reminder that this can be interrupted before/after "rollback" is sent
    Right v -> do
      -- If this is interrupted:
      -- - Before COMMIT is sent, it's equivalent to failing in the middle of the user-supplied action
      -- - After COMMIT is sent but before it's finished, a ROLLBACK will still be enqueued
      --   (check `cleanup`), which means:
      --      - If COMMIT has already completed by the time ROLLBACK is meant to be sent, ROLLBACK will
      --        produce a "WARNING: there is no transaction in progress". This isn't great, but
      --        is mostly harmless.
      --      - If COMMIT is cancelled by the future ROLLBACK, all is good as well.
      -- - After COMMIT is sent and completed, this is just interruption right after a successful
      --   operation, and `cleanup` won't enqueue a ROLLBACK.
      commit conn
      pure v
  where
    -- Release step of the bracket. On any exception other than an internal
    -- one, if the connection is not idle, flag that a ROLLBACK must be sent
    -- before the next command (it does not send anything itself).
    cleanup mEx () = case mEx of
      Nothing -> pure ()
      Just (fromException -> Just (_ :: IrrecoverableHpgsqlError)) -> pure () -- Do nothing if an internal error was thrown, just let it propagate
      Just _ex ->
        -- Mark ROLLBACK to be sent before next command
        STM.atomically $ do
          (_, txnStatus) <- fullTransactionStatus conn.internalConnectionState
          -- The rollback in case of a synchronous exception may have run, so we
          -- don't need another one in that case.
          unless (txnStatus == TransIdle) $
            STM.modifyTVar' conn.internalConnectionState $ \st ->
              st {mustIssueRollbackBeforeNextCommand = True}