-- | Marshaling and serialization
--
-- Generalizes 'Storable'. For details, see
-- https://github.com/well-typed/hs-bindgen/issues/649.
--
-- This module is intended to be imported qualified.
--
-- > import HsBindgen.Runtime.Prelude
-- > import HsBindgen.Runtime.Marshal qualified as Marshal
module HsBindgen.Runtime.Marshal (
    -- * Type Classes
    StaticSize(..)
  , ReadRaw(..)
  , WriteRaw(..)
  , EquivStorable(..)

    -- * Utility Functions
  , readRawByteOff
  , writeRawByteOff
  , readRawElemOff
  , writeRawElemOff
  , maybeReadRaw
  , with
  , withZero
  , new
  , newZero
  ) where

import Data.Complex (Complex ((:+)))
import Data.Int (Int16, Int32, Int64, Int8)
import Data.Proxy (Proxy (Proxy))
import Data.Word (Word16, Word32, Word64, Word8)
import Foreign.C qualified as C
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Marshal.Alloc qualified as Alloc
import Foreign.Marshal.Utils qualified as Utils
import Foreign.Ptr (FunPtr, Ptr)
import Foreign.Ptr qualified as Ptr
import Foreign.StablePtr (StablePtr)
import Foreign.Storable (Storable)
import Foreign.Storable qualified as Storable
import GHC.ForeignPtr (mallocForeignPtrAlignedBytes)

import HsBindgen.Runtime.PtrConst (PtrConst)

{-------------------------------------------------------------------------------
  Type Classes
-------------------------------------------------------------------------------}

-- | Size and alignment for values that have a static size in memory
--
-- Types that are instances of 'Storable' can derive this instance: an empty
-- instance declaration picks up the 'Storable'-based default methods below.
class StaticSize a where

  -- | Storage requirements (bytes)
  staticSizeOf :: Proxy a -> Int

  -- 'Storable.sizeOf' does not evaluate its argument, so passing 'undefined'
  -- only serves to fix the type.
  default staticSizeOf :: Storable a => Proxy a -> Int
  staticSizeOf _proxy = Storable.sizeOf @a undefined

  -- | Alignment (bytes)
  staticAlignment :: Proxy a -> Int

  -- As above: 'Storable.alignment' does not evaluate its argument.
  default staticAlignment :: Storable a => Proxy a -> Int
  staticAlignment _proxy = Storable.alignment @a undefined

-- | Values that can be read from memory
--
-- Types that are instances of 'Storable' can derive this instance: an empty
-- instance declaration uses 'Storable.peek' as the default.
class ReadRaw a where

  -- | Read a value from the given memory location
  --
  -- This function might require a properly aligned address to function
  -- correctly, depending on the architecture.
  readRaw :: Ptr a -> IO a

  default readRaw :: Storable a => Ptr a -> IO a
  readRaw = Storable.peek

-- | Values that can be written to memory
--
-- Types that are instances of 'Storable' can derive this instance: an empty
-- instance declaration uses 'Storable.poke' as the default.
class WriteRaw a where

  -- | Write a value to the given memory location
  --
  -- This function might require a properly aligned address to function
  -- correctly, depending on the architecture.
  writeRaw :: Ptr a -> a -> IO ()

  default writeRaw :: Storable a => Ptr a -> a -> IO ()
  writeRaw = Storable.poke

-- | Type used to derive a 'Storable' instance when the type has 'StaticSize',
-- 'ReadRaw', and 'WriteRaw' instances
--
-- Use the @DerivingVia@ GHC extension as follows:
--
-- @
-- {-# LANGUAGE DerivingVia #-}
--
-- data Foo = Foo { ... }
--   deriving Storable via EquivStorable Foo
-- @
newtype EquivStorable a = EquivStorable a

-- | 'Storable' expressed in terms of 'StaticSize', 'ReadRaw', and 'WriteRaw'
--
-- Only 'sizeOf', 'alignment', 'peek', and 'poke' are defined here; the other
-- 'Storable' methods fall back to their class defaults.
instance
     (ReadRaw a, StaticSize a, WriteRaw a)
  => Storable (EquivStorable a)
  where
    -- Pass a real 'Proxy' rather than 'undefined', so that a 'StaticSize'
    -- instance which pattern matches on the 'Proxy' constructor still works.
    sizeOf    _ = staticSizeOf    @a Proxy
    alignment _ = staticAlignment @a Proxy

    peek ptr = EquivStorable <$> readRaw (Ptr.castPtr ptr)

    poke ptr (EquivStorable x) = writeRaw (Ptr.castPtr ptr) x

{-------------------------------------------------------------------------------
  Utility Functions
-------------------------------------------------------------------------------}

-- | Read a value from the given memory location, given by a base address and an
-- offset
--
-- The offset is in bytes.  As with 'readRaw', the resulting address might need
-- to be properly aligned, depending on the architecture.
readRawByteOff :: ReadRaw a => Ptr b -> Int -> IO a
readRawByteOff ptr off = readRaw (ptr `Ptr.plusPtr` off)

-- | Write a value to the given memory location, given by a base address and an
-- offset
--
-- The offset is in bytes.  As with 'writeRaw', the resulting address might
-- need to be properly aligned, depending on the architecture.
writeRawByteOff :: WriteRaw a => Ptr b -> Int -> a -> IO ()
writeRawByteOff ptr off = writeRaw (ptr `Ptr.plusPtr` off)

-- | Read a value from a memory area regarded as an array of values of the same
-- kind
--
-- The first argument specifies the start address of the array.  The second
-- specifies the (zero-based) index into the array; it is scaled by
-- 'staticSizeOf' to obtain the byte offset.
readRawElemOff :: forall a. (ReadRaw a, StaticSize a) => Ptr a -> Int -> IO a
readRawElemOff ptr off = readRawByteOff ptr $ off * staticSizeOf @a Proxy

-- | Write a value to a memory area regarded as an array of values of the same
-- kind
--
-- The first argument specifies the start address of the array.  The second
-- specifies the (zero-based) index into the array; it is scaled by
-- 'staticSizeOf' to obtain the byte offset.
writeRawElemOff :: forall a.
     (StaticSize a, WriteRaw a)
  => Ptr a
  -> Int
  -> a
  -> IO ()
writeRawElemOff ptr off = writeRawByteOff ptr $ off * staticSizeOf @a Proxy

-- | Read a value from memory when passed a non-null pointer
--
-- Returns 'Nothing' for 'Ptr.nullPtr' without reading any memory; otherwise
-- reads the value with 'readRaw' and wraps it in 'Just'.
maybeReadRaw :: ReadRaw a => Ptr a -> IO (Maybe a)
maybeReadRaw ptr
    | ptr == Ptr.nullPtr = return Nothing
    | otherwise          = Just <$> readRaw ptr

-- | Allocate local memory, write the specified value, and call a function with
-- the pointer
--
-- The allocated memory is aligned to 'staticAlignment' and has size
-- 'staticSizeOf'.
--
-- Memory that is not written to by 'writeRaw' may contain arbitrary data.
--
-- The allocated memory is freed when the function terminates, either normally
-- or via an exception.  The passed pointer must therefore /not/ be used after
-- this.
with :: forall a b.
     (StaticSize a, WriteRaw a)
  => a
  -> (Ptr a -> IO b)
  -> IO b
with x f = Alloc.allocaBytesAligned size align $ \ptr -> do
    writeRaw ptr x
    f ptr
  where
    size, align :: Int
    size  = staticSizeOf    @a Proxy
    align = staticAlignment @a Proxy

-- | Allocate local memory, write the specified value, and call a function with
-- the pointer
--
-- The allocated memory is aligned to 'staticAlignment' and has size
-- 'staticSizeOf'.
--
-- The memory is filled with bytes of value zero before the value is written.
-- Memory that is not written to by 'writeRaw' contains zeros, not arbitrary
-- data.
--
-- The allocated memory is freed when the function terminates, either normally
-- or via an exception.  The passed pointer must therefore /not/ be used after
-- this.
withZero :: forall a b.
     (StaticSize a, WriteRaw a)
  => a
  -> (Ptr a -> IO b)
  -> IO b
withZero x f = Alloc.allocaBytesAligned size align $ \ptr -> do
    -- Zero the whole allocation first, so padding bytes are deterministic.
    Utils.fillBytes ptr 0 size
    writeRaw ptr x
    f ptr
  where
    size, align :: Int
    size  = staticSizeOf    @a Proxy
    align = staticAlignment @a Proxy

-- | Allocate memory, write the specified value, and return the 'ForeignPtr'
--
-- The allocated memory is aligned to 'staticAlignment' and has size
-- 'staticSizeOf'.
--
-- Memory that is not written to by 'writeRaw' may contain arbitrary data.
new :: forall a.
     (StaticSize a, WriteRaw a)
  => a
  -> IO (ForeignPtr a)
new x = do
    fptr <- mallocForeignPtrAlignedBytes size align
    withForeignPtr fptr $ \ptr -> writeRaw ptr x
    return fptr
  where
    size, align :: Int
    size  = staticSizeOf    @a Proxy
    align = staticAlignment @a Proxy

-- | Allocate memory, write the specified value, and return the 'ForeignPtr'
--
-- The allocated memory is aligned to 'staticAlignment' and has size
-- 'staticSizeOf'.
--
-- The memory is filled with bytes of value zero before the value is written.
-- Memory that is not written to by 'writeRaw' contains zeros, not arbitrary
-- data.
newZero :: forall a.
     (StaticSize a, WriteRaw a)
  => a
  -> IO (ForeignPtr a)
newZero x = do
    fptr <- mallocForeignPtrAlignedBytes size align
    withForeignPtr fptr $ \ptr -> do
      -- Zero the whole allocation first, so padding bytes are deterministic.
      Utils.fillBytes ptr 0 size
      writeRaw ptr x
    return fptr
  where
    size, align :: Int
    size  = staticSizeOf    @a Proxy
    align = staticAlignment @a Proxy

{-------------------------------------------------------------------------------
  Instances
-------------------------------------------------------------------------------}

-- C types: every instance below uses the 'Storable'-based default methods.

-- Character types

instance StaticSize C.CChar
instance ReadRaw    C.CChar
instance WriteRaw   C.CChar

instance StaticSize C.CSChar
instance ReadRaw    C.CSChar
instance WriteRaw   C.CSChar

instance StaticSize C.CUChar
instance ReadRaw    C.CUChar
instance WriteRaw   C.CUChar

instance StaticSize C.CWchar
instance ReadRaw    C.CWchar
instance WriteRaw   C.CWchar

-- Basic integral types

instance StaticSize C.CShort
instance ReadRaw    C.CShort
instance WriteRaw   C.CShort

instance StaticSize C.CUShort
instance ReadRaw    C.CUShort
instance WriteRaw   C.CUShort

instance StaticSize C.CInt
instance ReadRaw    C.CInt
instance WriteRaw   C.CInt

instance StaticSize C.CUInt
instance ReadRaw    C.CUInt
instance WriteRaw   C.CUInt

instance StaticSize C.CLong
instance ReadRaw    C.CLong
instance WriteRaw   C.CLong

instance StaticSize C.CULong
instance ReadRaw    C.CULong
instance WriteRaw   C.CULong

instance StaticSize C.CLLong
instance ReadRaw    C.CLLong
instance WriteRaw   C.CLLong

instance StaticSize C.CULLong
instance ReadRaw    C.CULLong
instance WriteRaw   C.CULLong

instance StaticSize C.CBool
instance ReadRaw    C.CBool
instance WriteRaw   C.CBool

-- Standard typedefs (ptrdiff_t, size_t, sig_atomic_t, intptr_t, intmax_t, ...)

instance StaticSize C.CPtrdiff
instance ReadRaw    C.CPtrdiff
instance WriteRaw   C.CPtrdiff

instance StaticSize C.CSize
instance ReadRaw    C.CSize
instance WriteRaw   C.CSize

instance StaticSize C.CSigAtomic
instance ReadRaw    C.CSigAtomic
instance WriteRaw   C.CSigAtomic

instance StaticSize C.CIntPtr
instance ReadRaw    C.CIntPtr
instance WriteRaw   C.CIntPtr

instance StaticSize C.CUIntPtr
instance ReadRaw    C.CUIntPtr
instance WriteRaw   C.CUIntPtr

instance StaticSize C.CIntMax
instance ReadRaw    C.CIntMax
instance WriteRaw   C.CIntMax

instance StaticSize C.CUIntMax
instance ReadRaw    C.CUIntMax
instance WriteRaw   C.CUIntMax

-- Time types

instance StaticSize C.CClock
instance ReadRaw    C.CClock
instance WriteRaw   C.CClock

instance StaticSize C.CTime
instance ReadRaw    C.CTime
instance WriteRaw   C.CTime

instance StaticSize C.CUSeconds
instance ReadRaw    C.CUSeconds
instance WriteRaw   C.CUSeconds

instance StaticSize C.CSUSeconds
instance ReadRaw    C.CSUSeconds
instance WriteRaw   C.CSUSeconds

-- Floating-point types

instance StaticSize C.CFloat
instance ReadRaw    C.CFloat
instance WriteRaw   C.CFloat

instance StaticSize C.CDouble
instance ReadRaw    C.CDouble
instance WriteRaw   C.CDouble

-- Pointer-like types: every instance below uses the 'Storable'-based default
-- methods.

instance StaticSize (Ptr a)
instance ReadRaw    (Ptr a)
instance WriteRaw   (Ptr a)

instance StaticSize (FunPtr a)
instance ReadRaw    (FunPtr a)
instance WriteRaw   (FunPtr a)

instance StaticSize (StablePtr a)
instance ReadRaw    (StablePtr a)
instance WriteRaw   (StablePtr a)

-- 'PtrConst' comes from "HsBindgen.Runtime.PtrConst".
instance StaticSize (PtrConst a)
instance ReadRaw    (PtrConst a)
instance WriteRaw   (PtrConst a)

-- Haskell base types: every instance below uses the 'Storable'-based default
-- methods.

-- Fixed-width integers, signed and unsigned side by side

instance StaticSize Int8
instance ReadRaw    Int8
instance WriteRaw   Int8

instance StaticSize Word8
instance ReadRaw    Word8
instance WriteRaw   Word8

instance StaticSize Int16
instance ReadRaw    Int16
instance WriteRaw   Int16

instance StaticSize Word16
instance ReadRaw    Word16
instance WriteRaw   Word16

instance StaticSize Int32
instance ReadRaw    Int32
instance WriteRaw   Int32

instance StaticSize Word32
instance ReadRaw    Word32
instance WriteRaw   Word32

instance StaticSize Int64
instance ReadRaw    Int64
instance WriteRaw   Int64

instance StaticSize Word64
instance ReadRaw    Word64
instance WriteRaw   Word64

-- Machine-word integers

instance StaticSize Int
instance ReadRaw    Int
instance WriteRaw   Int

instance StaticSize Word
instance ReadRaw    Word
instance WriteRaw   Word

-- Floating point

instance StaticSize Float
instance ReadRaw    Float
instance WriteRaw   Float

instance StaticSize Double
instance ReadRaw    Double
instance WriteRaw   Double

-- Miscellaneous

instance StaticSize Char
instance ReadRaw    Char
instance WriteRaw   Char

instance StaticSize Bool
instance ReadRaw    Bool
instance WriteRaw   Bool

instance StaticSize ()
instance ReadRaw    ()
instance WriteRaw   ()

--------------------------------------------------------------------------------

-- The instances for 'Complex' follow the 'Storable' instance for 'Complex' in
-- @base@, rewritten so that they rely only on the classes defined in this
-- module rather than on 'Storable'.

-- | A 'Complex' value occupies two consecutive parts of type @a@, so it is
-- twice the size of @a@ and has the alignment of @a@.
instance StaticSize a => StaticSize (Complex a) where
  staticSizeOf    _ = 2 * staticSizeOf (Proxy @a)
  staticAlignment _ = staticAlignment (Proxy @a)

-- | Reads the real part at element index 0 and the imaginary part at element
-- index 1, treating the 'Complex' value as an array of two @a@ values.
instance (ReadRaw a, StaticSize a) => ReadRaw (Complex a) where
  readRaw ptrComplex =
    let ptrPart :: Ptr a
        ptrPart = Ptr.castPtr ptrComplex
    in  (:+) <$> readRaw ptrPart <*> readRawElemOff ptrPart 1

-- | Writes the real part at element index 0 and the imaginary part at element
-- index 1, treating the 'Complex' value as an array of two @a@ values.
instance (StaticSize a, WriteRaw a) => WriteRaw (Complex a) where
  writeRaw ptrComplex (r :+ i) = do
    let ptrPart :: Ptr a
        ptrPart = Ptr.castPtr ptrComplex
    writeRaw ptrPart r
    writeRawElemOff ptrPart 1 i