module Database.PostgreSQL.PQTypes.Internal.Utils
  ( MkConstraint
  , mread
  , safePeekCString
  , safePeekCString'
  , cStringLenToBytea
  , byteaToCStringLen
  , textToCString
  , verifyPQTRes
  , withPGparam
  , throwLibPQError
  , throwLibPQTypesError
  , rethrowWithArrayError
  , hpqTypesError
  , unexpectedNULL
  ) where

import Control.Exception qualified as E
import Control.Monad
import Data.ByteString.Unsafe
import Data.Kind (Type)
import Data.Maybe
import Data.Text qualified as T
import Data.Text.Encoding qualified as T
import Foreign.C
import Foreign.ForeignPtr
import Foreign.Marshal.Alloc
import Foreign.Marshal.Utils
import Foreign.Ptr
import Foreign.Storable
import GHC.Exts
import GHC.Stack

import Database.PostgreSQL.PQTypes.Internal.C.Interface
import Database.PostgreSQL.PQTypes.Internal.C.Types
import Database.PostgreSQL.PQTypes.Internal.Error

-- | Combine every class in the type-level list @cs@, each applied to the same
-- monad @m@, into a single constraint. For example
-- @MkConstraint m '[MonadIO, MonadMask]@ reduces to
-- @(MonadIO m, (MonadMask m, ()))@.
type family MkConstraint (m :: Type -> Type) (cs :: [(Type -> Type) -> Constraint])
  :: Constraint where
  MkConstraint _ '[] = ()
  MkConstraint monad (cls ': rest) = (cls monad, MkConstraint monad rest)

-- | Total replacement for 'read': parse a value via its 'Read' instance,
-- succeeding only when 'reads' yields exactly one parse that consumes the
-- whole input. Any leftover input, including trailing whitespace, gives
-- 'Nothing' (unlike 'Text.Read.readMaybe', which tolerates trailing spaces).
mread :: Read a => String -> Maybe a
mread input = case reads input of
  [(value, "")] -> Just value
  _ -> Nothing

-- | Read a NUL-terminated C string, mapping a NULL pointer to 'Nothing'
-- instead of dereferencing it. Non-NULL strings are decoded with
-- 'peekCString'.
safePeekCString :: CString -> IO (Maybe String)
safePeekCString = maybePeek peekCString

-- | Variant of 'safePeekCString' that returns the empty string, rather than
-- 'Nothing', when given a NULL pointer.
safePeekCString' :: CString -> IO String
safePeekCString' = fmap (fromMaybe "") . safePeekCString

-- | Wrap a C string and its byte length as a 'PGbytea'. The pointer is stored
-- as-is (no copy is made), so the 'PGbytea' is only valid while that memory
-- is.
--
-- NOTE(review): the 'Int' length is narrowed to 'CInt' with 'fromIntegral',
-- which wraps silently for buffers of 2^31 bytes or more.
cStringLenToBytea :: CStringLen -> PGbytea
cStringLenToBytea (ptr, byteCount) =
  PGbytea {pgByteaData = ptr, pgByteaLen = fromIntegral byteCount}

-- | Unwrap a 'PGbytea' into a C string pointer and its byte length. The
-- pointer is returned unchanged (nothing is copied); the 'CInt' length is
-- widened to 'Int'.
byteaToCStringLen :: PGbytea -> CStringLen
byteaToCStringLen PGbytea {pgByteaData = ptr, pgByteaLen = byteCount} =
  (ptr, fromIntegral byteCount)

-- | Convert 'T.Text' to a UTF-8 encoded, NUL-terminated C string held by a
-- 'ForeignPtr'.
--
-- The encoded bytes are copied into a fresh buffer from
-- 'mallocForeignPtrBytes', so the result does not alias the intermediate
-- 'Data.ByteString.ByteString'. That buffer is released once the
-- 'ForeignPtr' becomes unreachable.
--
-- NOTE: a 'T.Text' containing @\\0@ encodes to bytes that include a zero
-- byte. C code reading the result as a NUL-terminated string will therefore
-- see it cut short at that point.
textToCString :: T.Text -> IO (ForeignPtr CChar)
textToCString txt = unsafeUseAsCStringLen (T.encodeUtf8 txt) $ \(src, len) -> do
  -- 'unsafeUseAsCStringLen' avoids a copy. That is fine here: we only read
  -- from 'src', and only inside this callback.
  -- One extra byte is reserved for the terminating NUL.
  fptr <- mallocForeignPtrBytes (len + 1)
  withForeignPtr fptr $ \dst -> do
    -- An empty ByteString may hand us a null source pointer. Passing NULL to
    -- memcpy (which 'copyBytes' uses) is undefined behaviour even for a zero
    -- length, so skip the copy entirely in that case.
    unless (len == 0) $ copyBytes dst src len
    pokeByteOff dst len (0 :: CChar)
  pure fptr

-- | Check the status code returned by a libpqtypes function. A code of @0@
-- is treated as failure: the error held in the given 'PGerror' is raised via
-- 'throwLibPQTypesError', with @ctx@ as context. Any other code is treated as
-- success and does nothing.
verifyPQTRes :: HasCallStack => Ptr PGerror -> String -> CInt -> IO ()
verifyPQTRes err ctx status = when (status == 0) $ throwLibPQTypesError err ctx

-- | 'alloca'-like wrapper that manages the lifetime of a 'PGparam' object.
--
-- A fresh 'PGparam' is created for the given connection with
-- 'c_PQparamCreate' and passed to the action. 'c_PQparamClear' is then
-- called on it through 'E.bracket', so cleanup runs even if the action
-- throws. If creation returns NULL, the contents of the temporary 'PGerror'
-- buffer handed to 'c_PQparamCreate' are raised via 'throwLibPQTypesError'.
withPGparam :: HasCallStack => Ptr PGconn -> (Ptr PGparam -> IO r) -> IO r
withPGparam conn action = E.bracket acquire c_PQparamClear action
  where
    acquire :: IO (Ptr PGparam)
    acquire = alloca $ \errBuf -> do
      newParam <- c_PQparamCreate conn errBuf
      if newParam == nullPtr
        then throwLibPQTypesError errBuf "withPGparam.create"
        else pure newParam

----------------------------------------

-- | Throw a 'LibPQError' carrying the connection's current error message, as
-- returned by 'c_PQerrorMessage'. A NULL message is read as the empty
-- string. A non-empty @ctx@ is prepended in the form @"<ctx>: <message>"@.
throwLibPQError :: HasCallStack => Ptr PGconn -> String -> IO a
throwLibPQError conn ctx = do
  rawMsg <- c_PQerrorMessage conn
  msg <- safePeekCString' rawMsg
  E.throwIO (LibPQError (withContext msg))
  where
    withContext msg
      | null ctx = msg
      | otherwise = ctx ++ ": " ++ msg

-- | Throw a 'LibPQError' built from the message ('pgErrorMsg') of the
-- 'PGerror' stored at the given pointer. A non-empty @ctx@ is prepended in
-- the form @"<ctx>: <message>"@.
throwLibPQTypesError :: HasCallStack => Ptr PGerror -> String -> IO a
throwLibPQTypesError err ctx = do
  pgErr <- peek err
  let msg = pgErrorMsg pgErr
  E.throwIO . LibPQError $ case ctx of
    [] -> msg
    _ -> ctx ++ ": " ++ msg

-- | Rethrow an exception as an 'ArrayItemError' that wraps the original
-- exception together with an array position. The stored position is the
-- supplied index plus one.
-- NOTE(review): presumably this converts a 0-based C index into a 1-based
-- PostgreSQL array position; confirm against callers.
rethrowWithArrayError :: HasCallStack => CInt -> E.SomeException -> IO a
rethrowWithArrayError idx (E.SomeException inner) =
  E.throwIO $
    ArrayItemError
      { arrItemError = inner
      , arrItemIndex = 1 + fromIntegral idx
      }

-- | Throw an 'HPQTypesError' carrying the given message.
hpqTypesError :: HasCallStack => String -> IO a
hpqTypesError msg = E.throwIO (HPQTypesError msg)

-- | Throw, via 'hpqTypesError', an error with the message
-- @unexpected NULL@.
unexpectedNULL :: HasCallStack => IO a
unexpectedNULL = hpqTypesError message
  where
    message = "unexpected NULL"