{-# LANGUAGE MagicHash #-}

-- We capitalize module names, but use camelCase/PascalCase in code:
--
-- - in type names:     FlamFoo, FooFlamBar
-- - in variable names: flamFoo, fooFlamBar

-- | Intended for qualified import.
--
-- @
-- import HsBindgen.Runtime.FLAM (WithFlam)
-- import HsBindgen.Runtime.FLAM qualified as FLAM
-- @
module HsBindgen.Runtime.FLAM (
    -- * Definitions
    Offset (..),
    NumElems (..),
    WithFlam (..),
    -- * Exceptions
    FlamLengthMismatch (..),
) where

import Control.Exception (Exception, throwIO)
import Data.Kind (Type)
import Data.Vector.Storable qualified as VS
import Data.Vector.Storable.Mutable qualified as VSM
import Foreign (Ptr, Storable)
import Foreign qualified
import GHC.Exts (Proxy#, proxy#)

import HsBindgen.Runtime.Marshal

{-------------------------------------------------------------------------------
  Definitions
-------------------------------------------------------------------------------}

-- | Location of the flexible array member (FLAM) relative to the start of the
-- struct whose fixed-size part is represented by @aux@.
--
-- The functional dependency @aux -> elem@ says that the struct type determines
-- the element type of its flexible array member.
class Offset elem aux | aux -> elem where
  -- | Offset that is added (in bytes, via 'Foreign.plusPtr') to a pointer to
  -- the struct in order to obtain a pointer to the first FLAM element.
  --
  -- The 'Proxy#' argument only selects the instance; it carries no data.
  offset :: Proxy# aux -> Int

-- | Structs whose fixed-size part (@aux@) determines how many elements the
-- flexible array member holds.
class Offset elem aux => NumElems elem aux | aux -> elem where
  -- | Number of elements in the flexible array member, computed from the value
  -- of the fixed-size part of the struct.
  --
  -- This count is used both to size the Haskell vector when peeking and to
  -- validate the provided vector when poking.
  numElems :: aux -> Int

-- | Data structure with flexible array member
--
-- Pairs the fixed-size part of a struct (@aux@) with the contents of its
-- flexible array member, held as a storable vector of @elem@.
data WithFlam elem aux = WithFlam
    { -- | Underlying data structure without FLAM
      aux  :: !aux
      -- | Contents of the flexible array member
      --
      -- We use the word "flam" for the flexible array member of the struct.
      -- We use the word "vector" to refer to its Haskell representation (as a
      -- vector).
    , flam :: {-# UNPACK #-} !(VS.Vector elem)
    }
  deriving stock Show

-- | Reading a struct with a flexible array member from raw memory is exactly
-- 'peek' from this module.
instance
       (Storable aux, Storable elem, NumElems elem aux)
    => ReadRaw (WithFlam elem aux) where
  readRaw = peek

-- | Writing a struct with a flexible array member to raw memory is exactly
-- 'poke' from this module; in particular it can throw 'FlamLengthMismatch'.
instance
       (Storable aux, Storable elem, NumElems elem aux)
    => WriteRaw (WithFlam elem aux) where
  writeRaw = poke

{-------------------------------------------------------------------------------
  Peek and poke
-------------------------------------------------------------------------------}

-- | Peek structure with flexible array member.
--
-- The fixed-size part (@aux@) is read first from the start of the struct. The
-- element count that 'numElems' reports for that value then determines how
-- many bytes are copied from the flexible array member into a freshly
-- allocated vector.
--
-- The length is taken on trust from the struct itself: no check is made here
-- that the memory behind @ptrStruct@ really holds that many elements.
peek :: forall aux elem.
     (Storable aux , Storable elem, NumElems elem aux)
  => Ptr (WithFlam elem aux) -> IO (WithFlam elem aux)
peek ptrStruct = do
    aux <- Foreign.peek (ptrToAux ptrStruct)
    let Size{sizeNumElems, sizeNumBytes} = flamSize aux
    -- The buffer is uninitialised until 'Foreign.copyBytes' below fills it.
    vector <- VSM.unsafeNew sizeNumElems
    VSM.unsafeWith vector $ \ptrVectorElems ->
        Foreign.copyBytes ptrVectorElems (ptrToFlam ptrStruct) sizeNumBytes
    -- Freezing without a copy is fine: the mutable vector is not used again.
    vector' <- VS.unsafeFreeze vector
    return (WithFlam aux vector')

-- | Poke structure with flexible array member.
--
-- Writes the fixed-size part (@aux@) to the start of the struct and then
-- copies the vector's elements into the flexible array member.
--
-- Throws 'FlamLengthMismatch' if the number of elements that 'numElems'
-- reports for @aux@ differs from the length of the provided vector. The check
-- happens before anything is written, so memory is left untouched in that
-- case.
poke :: forall aux elem.
     (Storable aux, Storable elem, NumElems elem aux)
  => Ptr (WithFlam elem aux) -> WithFlam elem aux -> IO ()
poke ptrStruct (WithFlam aux vector)
  | sizeNumElems /= VS.length vector =
      throwIO $ FlamLengthMismatch sizeNumElems (VS.length vector)
  | otherwise = do
      Foreign.poke (ptrToAux ptrStruct) aux
      VS.unsafeWith vector $ \ptrVectorElems ->
        Foreign.copyBytes (ptrToFlam ptrStruct) ptrVectorElems sizeNumBytes
  where
    -- Element count and byte count the struct itself claims for its FLAM.
    Size{sizeNumElems, sizeNumBytes} = flamSize aux

{-------------------------------------------------------------------------------
  Exceptions
-------------------------------------------------------------------------------}

-- | Thrown by 'poke' when the vector supplied for the flexible array member
-- does not have the number of elements the struct says it should have.
data FlamLengthMismatch = FlamLengthMismatch {
      -- | Number of elements according to the struct (via 'numElems')
      flamLengthStruct   :: Int
      -- | Length of the vector that was actually provided
    , flamLengthProvided :: Int
    }
  deriving stock (Show)

-- | Makes 'FlamLengthMismatch' throwable with 'throwIO'; all 'Exception'
-- methods keep their default implementations.
instance Exception FlamLengthMismatch

{-------------------------------------------------------------------------------
  Internal helpers
-------------------------------------------------------------------------------}

-- | Pointer to the fixed-size part of the struct.
--
-- This is a plain cast: the fixed-size part is read from and written to the
-- very same address as the struct itself.
ptrToAux :: Ptr (WithFlam elem aux) -> Ptr aux
ptrToAux = Foreign.castPtr

-- | Pointer to the first element of the flexible array member.
--
-- Obtained by advancing the struct pointer by the number of bytes given by
-- 'offset' for @aux@.
ptrToFlam :: forall elem aux.
     Offset elem aux
  => Ptr (WithFlam elem aux) -> Ptr elem
ptrToFlam ptrStruct = Foreign.plusPtr ptrStruct (offset (proxy# @aux))

-- Internal.
--
-- Size of a flexible array member, expressed both as an element count and as
-- the corresponding number of bytes. Values are produced by 'flamSize'.
data Size = Size{
      -- Number of elements in the flexible array member
      sizeNumElems :: Int
      -- Number of bytes occupied by those elements
    , sizeNumBytes    :: Int
    }

-- | Size of the flexible array member belonging to the given fixed-size part.
--
-- The element count comes from 'numElems'; the byte count is that number
-- multiplied by the 'Storable' size of a single element.
flamSize :: forall (elem :: Type) aux.
     (NumElems elem aux, Storable elem)
  => aux -> Size
flamSize aux = Size{
      sizeNumElems
      -- 'undefined' is only a type witness here: 'Foreign.sizeOf' does not
      -- evaluate its argument.
    , sizeNumBytes = sizeNumElems * Foreign.sizeOf (undefined :: elem)
    }
  where
    sizeNumElems :: Int
    sizeNumElems = numElems aux