module Hasql.Engine.PqProcedures.SelectTypeInfo
  ( SelectTypeInfo (..),
    SelectTypeInfoResult,
    run,
  )
where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Hasql.Codecs.Decoders.Value qualified as Decoders.Value
import Hasql.Codecs.Vocab qualified as Vocab
import Hasql.Codecs.Vocab.QualifiedTypeName qualified as Vocab.QualifiedTypeName
import Hasql.Codecs.Vocab.TypeInfo qualified as Vocab.TypeInfo
import Hasql.Comms.ResultDecoder qualified
import Hasql.Comms.Roundtrip qualified
import Hasql.Comms.RowDecoder qualified
import Hasql.Engine.Errors qualified as Errors
import Hasql.Platform.Prelude
import Hasql.Pq qualified as Pq
import PostgreSQL.Binary.Encoding qualified as Binary

-- | Input of the type-info lookup: which types to resolve.
newtype SelectTypeInfo = SelectTypeInfo
  { -- | The (schema name, type name) pairs to resolve. When this set is empty,
    -- 'run' returns an empty result without issuing a query.
    keys :: HashSet Vocab.QualifiedTypeName
  }

-- | Outcome of the lookup: for each resolved (schema name, type name) pair,
-- its 'Vocab.TypeInfo.TypeInfo' (scalar OID, then array OID). Names that could
-- not be resolved have no entry.
type SelectTypeInfoResult = HashMap Vocab.QualifiedTypeName Vocab.TypeInfo.TypeInfo

-- | Resolve the requested type names on the given connection.
--
-- An empty key set is answered immediately with an empty map, without touching
-- the connection. Otherwise the parameterised query built by 'roundtrip' is
-- executed via 'Hasql.Comms.Roundtrip.toSerialIO', and any roundtrip failure is
-- mapped to a session error with 'Errors.fromRoundtripError'.
run :: Pq.Connection -> SelectTypeInfo -> IO (Either Errors.SessionError SelectTypeInfoResult)
run connection params@(SelectTypeInfo keys)
  | HashSet.null keys = pure (Right HashMap.empty)
  | otherwise =
      fmap
        (first Errors.fromRoundtripError)
        (Hasql.Comms.Roundtrip.toSerialIO (roundtrip params) connection)

-- | Resolve type names to their scalar and array OIDs.
--
-- Parameters: @$1@ is a @text[]@ of schema names (NULL for unqualified keys)
-- and @$2@ is a @text[]@ of type names, aligned element-by-element.
--
-- * Unqualified names (NULL schema) are resolved with @to_regtype@, which goes
--   through PostgreSQL's type-name parser. Aliases such as @integer@ therefore
--   resolve to the type whose @typname@ is @int4@. Each row echoes back the
--   /requested/ name (@inputs.type_name@) rather than @pg_type.typname@, so the
--   decoded result is keyed by exactly what the caller asked for.
-- * Qualified names must match @pg_namespace.nspname@ and @pg_type.typname@
--   exactly, so echoing the catalog columns is equivalent to echoing the inputs.
--
-- OIDs are cast to @int4@ for transfer.
sql :: ByteString
sql =
  "with\n\
  \  inputs as (\n\
  \    select *\n\
  \    from unnest($1, $2) as x(schema_name, type_name)\n\
  \  ),\n\
  \  unnamespaced_results as (\n\
  \    select\n\
  \      null :: text as schema_name,\n\
  \      inputs.type_name :: text as type_name,\n\
  \      pg_type.oid :: int4 as type_oid,\n\
  \      pg_type.typarray :: int4 as array_oid\n\
  \    from inputs\n\
  \    join pg_type on pg_type.oid = to_regtype(inputs.type_name)\n\
  \    where inputs.schema_name is null\n\
  \  ),\n\
  \  namespaced_results as (\n\
  \    select\n\
  \      pg_namespace.nspname :: text as schema_name,\n\
  \      pg_type.typname :: text as type_name,\n\
  \      pg_type.oid :: int4 as type_oid,\n\
  \      pg_type.typarray :: int4 as array_oid\n\
  \    from inputs\n\
  \    join pg_namespace on pg_namespace.nspname = inputs.schema_name\n\
  \    join pg_type\n\
  \      on pg_type.typname = inputs.type_name\n\
  \      and pg_type.typnamespace = pg_namespace.oid\n\
  \    where inputs.schema_name is not null\n\
  \  )\n\
  \select * from unnamespaced_results\n\
  \union\n\
  \select * from namespaced_results"

-- | The lookup as a single parameterised query. It sends 'sql' with the
-- parameters produced by 'encodeParams', passes 'Pq.Binary' as the format
-- argument, and decodes the rows with 'decoder'. The roundtrip carries no extra
-- context (@()@).
roundtrip :: SelectTypeInfo -> Hasql.Comms.Roundtrip.Roundtrip () SelectTypeInfoResult
roundtrip params =
  Hasql.Comms.Roundtrip.queryParams context sql encodedParams resultFormat decoder
  where
    context = ()
    encodedParams = encodeParams params
    resultFormat = Pq.Binary

-- | Encode the lookup keys as the two positional query parameters, both
-- binary-format @text[]@ values built from the same list of keys, so they are
-- aligned element-by-element:
--
-- * @$1@: schema names, with SQL NULL for keys that have no schema;
-- * @$2@: type names.
--
-- 25 is the OID of @text@ (the array element type); 1009 is the OID of @text[]@.
encodeParams :: SelectTypeInfo -> [Maybe (Pq.Oid, ByteString, Pq.Format)]
encodeParams (SelectTypeInfo keys) =
  [textArrayParam schemaArray, textArrayParam typeArray]
  where
    (schemaNames, typeNames) =
      unzip (fmap Vocab.QualifiedTypeName.toNameTuple (HashSet.toList keys))

    schemaArray = textArrayBytes (fmap nullableText schemaNames)
    typeArray = textArrayBytes (fmap (Binary.encodingArray . Binary.text_strict) typeNames)

    -- One-dimensional array of @text@ elements, rendered to bytes.
    textArrayBytes elements =
      Binary.encodingBytes (Binary.array 25 (Binary.dimensionArray foldl' id elements))

    -- Wrap encoded bytes as a binary @text[]@ parameter.
    textArrayParam bytes = Just (Pq.Oid 1009, bytes, Pq.Binary)

    -- Absent schema names become NULL array elements.
    nullableText Nothing = Binary.nullArray
    nullableText (Just text) = Binary.encodingArray (Binary.text_strict text)

-- | Fold the returned rows into a map, starting from an empty map.
-- Each row's (schema name, type name) becomes the key, and its
-- (type OID, array OID) becomes the 'Vocab.TypeInfo.TypeInfo' value.
-- Rows are added with 'HashMap.insert', so a repeated key overwrites the
-- previous entry.
decoder :: Hasql.Comms.ResultDecoder.ResultDecoder SelectTypeInfoResult
decoder = Hasql.Comms.ResultDecoder.foldl insertRow HashMap.empty rowDecoder
  where
    insertRow acc (schemaName, typeName, typeOid, arrayOid) =
      HashMap.insert key info acc
      where
        key = Vocab.QualifiedTypeName.QualifiedTypeName schemaName typeName
        info = Vocab.TypeInfo.TypeInfo typeOid arrayOid

-- | Decode one result row. Its columns are: a nullable schema name, a
-- non-null type name, and the non-null scalar and array OIDs. The OIDs arrive
-- as @int4@ and are converted to 'Word32' with 'fromIntegral'. That conversion
-- wraps negative values back into the unsigned range.
rowDecoder :: Hasql.Comms.RowDecoder.RowDecoder (Maybe Text, Text, Word32, Word32)
rowDecoder =
  (,,,)
    <$> column Hasql.Comms.RowDecoder.nullableColumn Decoders.Value.text
    <*> column Hasql.Comms.RowDecoder.nonNullableColumn Decoders.Value.text
    <*> oidColumn
    <*> oidColumn
  where
    oidColumn :: Hasql.Comms.RowDecoder.RowDecoder Word32
    oidColumn =
      column Hasql.Comms.RowDecoder.nonNullableColumn (fromIntegral <$> Decoders.Value.int4)

    -- Build a column decoder from a value decoder, using the given row-level
    -- constructor (nullable or non-nullable). The expected base OID comes from
    -- the value decoder, and the parser is given an empty OID cache.
    column mkColumn valueDecoder =
      mkColumn
        (Decoders.Value.toBaseOid valueDecoder)
        (Decoders.Value.toByteStringParser valueDecoder mempty)