module Database.Persist.Sql.Lifted.MonadSqlBackend
  ( MonadSqlBackend (..)
  , liftSql
  ) where

import Prelude

import Control.Exception.Annotated.UnliftIO (checkpointCallStack)
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Control.Monad.Reader (ReaderT (..), asks)
import Database.Persist.Sql (SqlBackend)
import Database.Persist.Sql.Lifted.HasSqlBackend (HasSqlBackend, getSqlBackend)
import GHC.Stack (HasCallStack)

-- | Monads that can hand out a 'SqlBackend', so that database queries can
--   be run in them without passing the backend around explicitly.
--
--   Every instance must also be 'MonadUnliftIO'.
class (MonadUnliftIO m) => MonadSqlBackend m where
  -- | Fetch the 'SqlBackend' available in the current context.
  getSqlBackendM
    :: m SqlBackend

-- | Any 'ReaderT' whose environment provides a 'SqlBackend' (via
--   'HasSqlBackend') can supply that backend by reading it from the
--   environment.
instance (HasSqlBackend r, MonadUnliftIO m) => MonadSqlBackend (ReaderT r m) where
  -- Project the backend out of the reader environment with 'getSqlBackend'.
  getSqlBackendM = asks getSqlBackend

-- | Generalize from 'SqlPersistT' to 'MonadSqlBackend'.
--
--   Runs a @'ReaderT' 'SqlBackend' m@ action directly in @m@. It fetches the
--   backend with 'getSqlBackendM' and passes it to the action.
--
--   The whole computation is wrapped in 'checkpointCallStack'. Per the
--   annotated-exception library, that attaches the caller's call stack to
--   exceptions escaping the action, which is why 'HasCallStack' is required.
liftSql
  :: forall m a. (HasCallStack, MonadSqlBackend m) => ReaderT SqlBackend m a -> m a
liftSql (ReaderT f) = checkpointCallStack $ getSqlBackendM >>= f