module Database.Persist.Sql.Lifted.Savepoint
  ( rollbackWhen
  ) where

import Prelude ((-))

import Control.Applicative (pure)
import Control.Monad (replicateM)
import Control.Monad.IO.Class (MonadIO)
import Data.Bool (Bool)
import Data.Char (Char)
import Data.Function (($))
import Data.Functor ((<$>))
import Data.Semigroup ((<>))
import Data.Text (Text)
import Data.Text qualified as T
import Data.Vector (Vector)
import Data.Vector qualified as V
import Database.Persist.Sql.Lifted.MonadSqlBackend (MonadSqlBackend)
import Database.Persist.Sql.Lifted.Persistent (rawExecute)
import GHC.Stack (HasCallStack)
import System.Random (randomRIO)

-- | Create a new transaction @SAVEPOINT@ and return its name.
--
-- The name is @savepoint_@ followed by 8 characters produced by
-- 'randomCharacter'. It is concatenated verbatim (unquoted) into the
-- @SAVEPOINT@ statement, and the same text is what callers later pass to
-- the release / rollback helpers.
newSavepoint :: forall m. (HasCallStack, MonadSqlBackend m) => m Text
newSavepoint = do
  suffix <- T.pack <$> replicateM 8 randomCharacter
  let name = "savepoint_" <> suffix
  rawExecute ("SAVEPOINT " <> name) []
  pure name

-- | Pick a random character from 'characterSet'.
--
-- An index in @[0, V.length characterSet - 1]@ is drawn with 'randomRIO',
-- so the 'V.!' lookup stays in bounds as long as 'characterSet' is
-- non-empty (it is a fixed, non-empty constant).
randomCharacter :: forall m. MonadIO m => m Char
randomCharacter = do
  index <- randomRIO (0, V.length characterSet - 1)
  pure (characterSet V.! index)

-- | Alphabet for random savepoint-name suffixes: the lowercase ASCII letters
-- @a@ through @z@ followed by the digits @0@ through @9@ (36 characters).
--
-- Both enumeration ranges are inclusive. The digit range starts at @'0'@ so
-- that every decimal digit is available. Keep this restricted to plain
-- letters and digits: 'newSavepoint' splices the generated name into SQL
-- text without any quoting or escaping.
characterSet :: Vector Char
characterSet = V.fromList $ ['a' .. 'z'] <> ['0' .. '9']

-- | Release a @SAVEPOINT@ by issuing @RELEASE SAVEPOINT@ for the given name.
--
-- The name is concatenated into the statement unquoted and unescaped, so it
-- must be a trusted identifier, such as one returned by 'newSavepoint'.
releaseSavepoint :: forall m. (HasCallStack, MonadSqlBackend m) => Text -> m ()
releaseSavepoint name = rawExecute statement []
  where
    statement = "RELEASE SAVEPOINT " <> name

-- | Roll back to a @SAVEPOINT@ by issuing @ROLLBACK TO SAVEPOINT@ for the
-- given name.
--
-- The name is concatenated into the statement unquoted and unescaped, so it
-- must be a trusted identifier, such as one returned by 'newSavepoint'.
rollbackToSavepoint
  :: forall m. (HasCallStack, MonadSqlBackend m) => Text -> m ()
rollbackToSavepoint name = rawExecute statement []
  where
    statement = "ROLLBACK TO SAVEPOINT " <> name

-- | Runs a SQL action with SAVEPOINT, rolling back when specified.
--
-- A fresh savepoint is created before the action runs. Afterwards, if the
-- predicate holds for the action's result, the transaction is rolled back to
-- that savepoint. Otherwise the savepoint is released. The action's result
-- is returned in both cases.
--
-- NOTE(review): there is no exception handling here. If the action throws,
-- the exception propagates without @ROLLBACK TO SAVEPOINT@ or
-- @RELEASE SAVEPOINT@ being issued for this savepoint. Cleaning that up would
-- presumably need an unlifting/bracketing capability that this signature does
-- not obviously provide (confirm against 'MonadSqlBackend's superclasses).
rollbackWhen
  :: forall m a
   . (HasCallStack, MonadSqlBackend m)
  => (a -> Bool)
  -- ^ When to ROLLBACK based on the result of the action
  -> m a
  -- ^ The action to be run
  -> m a
rollbackWhen shouldRollback act = do
  savepoint <- newSavepoint
  result <- act
  let undoChanges = shouldRollback result
  -- Choose how to close the savepoint, then apply it to the savepoint's name.
  (if undoChanges then rollbackToSavepoint else releaseSavepoint) savepoint
  pure result