module Hasql.Engine.Statement
  ( Statement (..),
    preparable,
    unpreparable,
    refineResult,
    toSql,
    compilePreparedStatementData,
    compileUnpreparedStatementData,
  )
where

import Data.Text.Encoding qualified as TextEncoding
import Data.Vector qualified as Vector
import Hasql.Codecs.Encoders.Params qualified as Params
import Hasql.Codecs.RequestingOid qualified as RequestingOid
import Hasql.Codecs.Vocab qualified as Vocab
import Hasql.Codecs.Vocab.OidCache qualified as Vocab.OidCache
import Hasql.Codecs.Vocab.ParamMeta (ParamMeta (..))
import Hasql.Codecs.Vocab.TypeRef qualified as Vocab.TypeRef
import Hasql.Comms.ResultDecoder qualified as ResultDecoder
import Hasql.Decoders qualified as Decoders
import Hasql.Encoders qualified as Encoders
import Hasql.Engine.Decoders.Result qualified as Decoders.Result
import Hasql.Platform.Prelude

-- |
-- Specification of a strictly single-statement query, which can be parameterized and prepared.
-- It encapsulates the mapping of parameters and results in association with an SQL template.
--
-- Following is an example of a declaration of a prepared statement with its associated codecs.
--
-- @
-- selectSum :: 'Statement' (Int64, Int64) Int64
-- selectSum =
--   'preparable' sql encoder decoder
--   where
--     sql =
--       \"select ($1 + $2)\"
--     encoder =
--       ('fst' '>$<' Encoders.'Hasql.Encoders.param' (Encoders.'Hasql.Encoders.nonNullable' Encoders.'Hasql.Encoders.int8')) '<>'
--       ('snd' '>$<' Encoders.'Hasql.Encoders.param' (Encoders.'Hasql.Encoders.nonNullable' Encoders.'Hasql.Encoders.int8'))
--     decoder =
--       Decoders.'Hasql.Decoders.singleRow' (Decoders.'Hasql.Decoders.column' (Decoders.'Hasql.Decoders.nonNullable' Decoders.'Hasql.Decoders.int8'))
-- @
--
-- The statement above accepts a product of two parameters of type 'Int64'
-- and produces a single result of type 'Int64'.
data Statement params result
  = Statement
  { -- | The SQL template, already encoded as UTF-8 so it can be sent as-is on
    -- every execution.
    sql :: ByteString,
    -- | One metadata entry per parameter (type reference, dimensionality and
    -- format flag). It is computed when the statement is built and shared by
    -- all executions.
    columnsMetadata :: Vector ParamMeta,
    -- | Encodes a parameters value into one wire value per parameter, using
    -- the given OID cache.
    serializer :: Vocab.OidCache -> params -> [Maybe ByteString],
    -- | Renders a parameters value as text, one entry per parameter, for use in
    -- error reports.
    printer :: params -> [Text],
    -- | The combined set of unknown type names referenced by the parameter
    -- encoder and by the result decoder. It is computed once at construction.
    unknownTypes :: HashSet Vocab.QualifiedTypeName,
    -- | The result decoder with the outer 'Decoders.Result' wrapper removed, so
    -- that only the 'RequestingOid.RequestingOid' layer remains.
    decoder :: RequestingOid.RequestingOid (ResultDecoder.ResultDecoder result),
    -- | 'True' for statements built with 'preparable' and 'False' for those built
    -- with 'unpreparable'. It says whether server-side preparation is allowed.
    isPrepared :: Bool
  }

-- |
-- Construct a preparable statement.
--
-- Use this for statements that will be executed multiple times with different parameters.
-- Preparable statements are cached by PostgreSQL, which avoids reconstructing the execution plan each time.
--
-- Suitable for applications with a limited amount of queries that don't generate SQL dynamically.
preparable ::
  -- | SQL template with parameters in positional notation (@$1@, @$2@, etc.)
  Text ->
  -- | Parameters encoder
  Encoders.Params params ->
  -- | Result decoder
  Decoders.Result result ->
  Statement params result
preparable = buildStatement True

-- |
-- Construct an unpreparable statement.
--
-- Use this for statements that are dynamically generated or executed only once.
-- Unpreparable statements are not cached by PostgreSQL.
--
-- Suitable for dynamic SQL or one-off queries.
unpreparable ::
  -- | SQL template with parameters in positional notation (@$1@, @$2@, etc.)
  Text ->
  -- | Parameters encoder
  Encoders.Params params ->
  -- | Result decoder
  Decoders.Result result ->
  Statement params result
unpreparable = buildStatement False

-- |
-- Shared implementation of 'preparable' and 'unpreparable'.
--
-- Every 'Statement' field is derived from the SQL text and the two codecs. The
-- two public constructors differ only in the value stored in 'isPrepared'.
-- Keeping this in one place guarantees they cannot drift apart.
buildStatement ::
  -- | Value stored in 'isPrepared'
  Bool ->
  -- | SQL template with parameters in positional notation (@$1@, @$2@, etc.)
  Text ->
  -- | Parameters encoder
  Encoders.Params params ->
  -- | Result decoder
  Decoders.Result result ->
  Statement params result
buildStatement prepared sqlText encoder resultDecoder =
  Statement
    { sql = TextEncoding.encodeUtf8 sqlText,
      columnsMetadata = Params.toColumnsMetadata encoder,
      serializer = Params.toSerializer encoder,
      printer = Params.toPrinter encoder,
      -- Unknown types come from both sides of the statement: parameters and results.
      unknownTypes = Params.toUnknownTypes encoder <> RequestingOid.toUnknownTypes rawDecoder,
      decoder = rawDecoder,
      isPrepared = prepared
    }
  where
    -- Unwrapped once and shared by 'unknownTypes' and 'decoder'.
    rawDecoder = Decoders.Result.unwrap resultDecoder

-- Mapping over the result only rewrites the decoder. The SQL, parameter
-- metadata, serializer and printer are carried over unchanged.
instance Functor (Statement params) where
  {-# INLINE fmap #-}
  fmap mapResult stmt =
    stmt {decoder = fmap mapResult <$> decoder stmt}

-- Filtering is delegated to the result decoder; every other field is kept as-is.
instance Filterable (Statement params) where
  {-# INLINE mapMaybe #-}
  mapMaybe keep stmt =
    stmt {decoder = mapMaybe keep <$> decoder stmt}

-- The input function is applied before both the serializer and the printer see
-- the parameters. The output function is applied to whatever the decoder produces.
instance Profunctor Statement where
  {-# INLINE dimap #-}
  dimap mapInput mapOutput stmt =
    stmt
      { serializer = \oidCache input -> serializer stmt oidCache (mapInput input),
        printer = \input -> printer stmt (mapInput input),
        decoder = fmap mapOutput <$> decoder stmt
      }

-- |
-- Post-process the result of a statement with a function that may reject it.
--
-- When the refiner returns 'Left', the running session fails with the
-- 'Hasql.Errors.UnexpectedResultStatementError' error. Only the result decoder
-- is affected; SQL and parameter handling are unchanged.
--
-- This is particularly handy for refining the results of statements generated by
-- <http://hackage.haskell.org/package/hasql-th the \"hasql-th\" library>.
refineResult :: (a -> Either Text b) -> Statement params a -> Statement params b
refineResult refiner stmt =
  stmt {decoder = ResultDecoder.refine refiner <$> decoder stmt}

-- | Recover the SQL template of a statement as 'Text'.
--
-- The stored bytes are decoded leniently. Invalid UTF-8 sequences are replaced
-- rather than raising an error. Such sequences can only appear when the record
-- is built directly instead of through 'preparable' or 'unpreparable', which
-- always UTF-8-encode their input.
toSql :: Statement params result -> Text
toSql = TextEncoding.decodeUtf8Lenient . sql

-- | Compile prepared-statement data: resolve OIDs and pair encoded values with their format flags.
--
-- The result is split into two lists:
--
-- * the parameter type OIDs, which are needed when preparing;
-- * the encoded values, each paired with its format flag. Parameters whose
--   encoding is 'Nothing' stay 'Nothing'.
--
-- NOTE(review): 'zipWith' silently truncates to the shorter list. If the
-- serializer ever yields a different number of values than 'columnsMetadata'
-- has entries, the extra items are dropped without an error.
compilePreparedStatementData ::
  Statement params result ->
  Vocab.OidCache ->
  params ->
  ([Word32], [Maybe (ByteString, Bool)])
compilePreparedStatementData stmt oidCache params =
  unzip (zipWith compileParam paramMetas encodedValues)
  where
    paramMetas = Vector.toList (columnsMetadata stmt)
    encodedValues = serializer stmt oidCache params
    compileParam meta@(ParamMeta _ _ fmt) encoding =
      (resolveParamOid oidCache meta, fmap (,fmt) encoding)

-- | Compile unprepared-statement data: resolve OIDs inline with encoded values.
--
-- Each parameter becomes one @(oid, bytes, format)@ triple. A parameter whose
-- encoding is 'Nothing' becomes 'Nothing' as a whole, so its resolved OID is
-- not transmitted for it.
--
-- NOTE(review): as in 'compilePreparedStatementData', a length mismatch between
-- the serializer output and 'columnsMetadata' is silently truncated by 'zipWith'.
compileUnpreparedStatementData ::
  Statement params result ->
  Vocab.OidCache ->
  params ->
  [Maybe (Word32, ByteString, Bool)]
compileUnpreparedStatementData stmt oidCache params =
  zipWith compileParam paramMetas encodedValues
  where
    paramMetas = Vector.toList (columnsMetadata stmt)
    encodedValues = serializer stmt oidCache params
    compileParam meta@(ParamMeta _ _ fmt) encoding =
      fmap (\bytes -> (resolveParamOid oidCache meta, bytes, fmt)) encoding

-- | Resolve the type OID to send for one parameter.
--
-- * A 'Vocab.TypeRef.KnownOid' is used as-is, whatever the dimensionality.
-- * A 'Vocab.TypeRef.NamedType' is looked up in the cache. Dimensionality 0
--   selects the scalar OID; any other dimensionality selects the array OID.
--
-- The result is @0@ in two cases: when the scalar entry is missing, or when an
-- array is requested and the array entry is missing. libpq treats a zero
-- parameter type as "let the server infer it".
--
-- Shared by both compile functions so that their OID resolution cannot diverge.
resolveParamOid :: Vocab.OidCache -> ParamMeta -> Word32
resolveParamOid oidCache (ParamMeta typeRef dim _) =
  case typeRef of
    Vocab.TypeRef.NamedType name ->
      case Vocab.OidCache.lookupTypeNameScalar name oidCache of
        Nothing -> 0
        Just scalarOid
          | dim == 0 -> scalarOid
          | otherwise -> fromMaybe 0 (Vocab.OidCache.lookupTypeNameArray name oidCache)
    Vocab.TypeRef.KnownOid oid -> oid