{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE NoMonomorphismRestriction #-}

-- |
-- Description : Zero-initialized allocation for GPU-aligned types
-- Copyright   : (c) Jeremy Nuttall, 2025
-- License     : BSD-3-Clause
-- Maintainer  : jeremy@jeremy-nuttall.com
-- Stability   : experimental
--
-- Allocation helpers with guaranteed zero-initialized padding.
-- Use these instead of 'alloca' or 'malloc' to avoid garbage in padding bytes.
module Foreign.GPU.Marshal.Aligned (
  -- * Utilities for 'Packed' values
  PackedPtr,
  withPacked,
  allocaPacked,

  -- * Utilities for 'Strided' values
  StridedPtr,
  withStrided,
  allocaStrided,

  -- * Utilities for runtime-length arrays
  alignedCopyVector,
) where

import qualified Data.Vector.Storable as SV
import UnliftIO (MonadIO, MonadUnliftIO, liftIO)
import UnliftIO.Foreign

import Foreign.GPU.Storable.Aligned

-- | Convenience synonym for an 'AlignedPtr' (tagged with @layout@) that
-- points at a @'Packed' layout a@ value.
type PackedPtr layout a = AlignedPtr layout (Packed layout a)

-- | Wraps a value in 'Packed', writes it into a freshly allocated buffer
-- and runs the continuation with an 'AlignedPtr' to that buffer.
--
-- The buffer is zero-filled before the value is poked, so any bytes the
-- 'Packed' 'Storable' instance does not write (e.g. padding) are zero rather
-- than leftover garbage. The storage is freed automatically once the
-- continuation finishes; the pointer must not escape it.
withPacked
  :: forall layout a m b
   . (MonadUnliftIO m, AlignedStorable layout a)
  => a
  -> (PackedPtr layout a -> m b)
  -> m b
withPacked value k =
  withZeroed (Packed value) $ \rawPtr -> k (AlignedPtr rawPtr)
{-# INLINEABLE withPacked #-}

-- | Allocates temporary, zero-initialized storage sized for a 'Packed' value
-- and passes an 'AlignedPtr' to it to the continuation.
--
-- Every byte of the buffer is zero, but no value is written into it (unlike
-- 'withPacked'); 'poke' one before reading. The memory comes from 'alloca'.
-- GHC implements that as a pinned heap allocation, not on the C stack. It is
-- released automatically when the continuation finishes, so the pointer is
-- only valid within the continuation.
allocaPacked
  :: forall layout a b m
   . (MonadUnliftIO m, AlignedStorable layout a)
  => (PackedPtr layout a -> m b)
  -> m b
allocaPacked f = allocaZeroed (f . AlignedPtr)
{-# INLINEABLE allocaPacked #-}

-- | Convenience synonym for an 'AlignedPtr' (tagged with @layout@) that
-- points at a @'Strided' layout a@ value.
type StridedPtr layout a = AlignedPtr layout (Strided layout a)

-- | Wraps a value in 'Strided', writes it into a freshly allocated buffer
-- and runs the continuation with an 'AlignedPtr' to that buffer.
--
-- The buffer is zero-filled before the value is poked, so any bytes the
-- 'Strided' 'Storable' instance does not write (e.g. padding) are zero rather
-- than leftover garbage. The storage is freed automatically once the
-- continuation finishes; the pointer must not escape it.
withStrided
  :: (MonadUnliftIO m, AlignedStorable layout a)
  => a
  -> (StridedPtr layout a -> m b)
  -> m b
withStrided value k =
  withZeroed (Strided value) $ \rawPtr -> k (AlignedPtr rawPtr)
{-# INLINEABLE withStrided #-}

-- | Allocates temporary, zero-initialized storage sized for a 'Strided'
-- value and passes an 'AlignedPtr' to it to the continuation.
--
-- Every byte of the buffer is zero, but no value is written into it (unlike
-- 'withStrided'); 'poke' one before reading. The memory comes from 'alloca'.
-- GHC implements that as a pinned heap allocation, not on the C stack. It is
-- released automatically when the continuation finishes, so the pointer is
-- only valid within the continuation.
allocaStrided :: (MonadUnliftIO m, AlignedStorable layout a) => (StridedPtr layout a -> m b) -> m b
allocaStrided f = allocaZeroed (f . AlignedPtr)
{-# INLINEABLE allocaStrided #-}

-- | Copies the raw bytes of a storable vector of 'Strided' elements to the
-- destination pointer with a single 'copyBytes'.
--
-- Exactly @'SV.length' v * 'sizeOf' ('Strided' layout a)@ bytes are copied.
-- No bounds check is performed: the caller must ensure the destination has
-- at least that many writable bytes and does not overlap the vector's
-- storage.
alignedCopyVector
  :: forall layout a m
   . (MonadIO m, AlignedStorable layout a)
  => AlignedPtr layout a
  -> SV.Vector (Strided layout a)
  -> m ()
alignedCopyVector (AlignedPtr dest) vec =
  liftIO (SV.unsafeWith vec copyAll)
  where
    -- Size of one element as laid out inside the storable vector.
    elemBytes :: Int
    elemBytes = sizeOf (undefined :: Strided layout a)

    totalBytes :: Int
    totalBytes = SV.length vec * elemBytes

    copyAll :: Ptr (Strided layout a) -> IO ()
    copyAll src = copyBytes dest (castPtr src) totalBytes
{-# INLINE alignedCopyVector #-}

--------------------------------------------------------------------------------
-- Utilities
--------------------------------------------------------------------------------

-- | Overwrites the first @'sizeOf' (undefined :: a)@ bytes at the pointer
-- with zeros.
zeroPtr :: forall a m. (MonadIO m, Storable a) => Ptr a -> m ()
zeroPtr p = liftIO $ fillBytes p 0 byteCount
  where
    byteCount :: Int
    byteCount = sizeOf (undefined :: a)
{-# INLINE zeroPtr #-}

-- | Variant of 'alloca' whose buffer is zero-filled before the continuation
-- receives it. The memory is released when the continuation finishes.
allocaZeroed :: (MonadUnliftIO m, Storable a) => (Ptr a -> m b) -> m b
allocaZeroed k = alloca $ \p -> do
  zeroPtr p
  k p
{-# INLINE allocaZeroed #-}

-- | Allocates a zero-filled buffer, 'poke's the value into it and runs the
-- continuation on the pointer.
--
-- Zeroing happens before the 'poke', so any bytes the value's 'Storable'
-- instance does not write (such as padding) are zero rather than garbage.
withZeroed :: (MonadUnliftIO m, Storable a) => a -> (Ptr a -> m b) -> m b
withZeroed x k = allocaZeroed $ \p -> liftIO (poke p x) >> k p
{-# INLINE withZeroed #-}