{-# LANGUAGE UnboxedTuples #-}

-- may as well export everything: the interface is highly unsafe anyway
module Bytezap.Poke where

import GHC.Exts
import Raehik.Compat.GHC.Exts.GHC908MemcpyPrimops qualified as MemcpyPrimops

import GHC.Word ( Word8(W8#) )

import Data.ByteString qualified as BS
import Data.ByteString.Internal qualified as BS

import Control.Monad ( void )

import Raehik.Compat.Data.Primitive.Types

import GHC.ForeignPtr

import Control.Monad.Primitive

import Bytezap.Struct qualified as Struct

{- | Unboxed buffer write operation.

The next offset must be greater than or equal to the input buffer offset.
This is not checked.

Note that the only way to find out the length of a write is to perform it. But
you can't perform a write without providing a correctly-sized buffer. Thus, you
may only use a 'Poke#' when you have a buffer large enough to fit its maximum
write length -- which in turn means you must track write lengths separately.
('Bytezap.Write.Write' does this.)

I provide this highly unsafe, seemingly unhelpful type because it's a
requirement for 'Bytezap.Write.Write', and here I can better guarantee
performance because I don't need to worry about laziness.

We cannot be polymorphic on the pointer type unless we box the pointer.
We thus limit ourselves to writing to 'Addr#'s, and not 'MutableByteArray#'s.
(I figure we're most interested in @ByteString@s, which use 'Addr#'.)

Note that if we did provide write length, then the next offset might appear
superfluous. But that next offset is usually already calculated, and may be
passed directly to sequenced writes, unlike if we returned a write length which
would need to be added to the original offset.
-}
type Poke# s
  =  Addr#                -- ^ start of the destination buffer
  -> Int#                 -- ^ offset within the buffer at which this write begins
  -> State# s             -- ^ incoming state token
  -> (# State# s, Int# #) -- ^ outgoing state token, and the offset just past the write

-- | A boxed wrapper around an unboxed 'Poke#' buffer write.
--
-- The newtype exists so that writes can be given instances (see the
-- 'Semigroup' and 'Monoid' instances below); 'unPoke' recovers the raw write.
newtype Poke s = Poke
  { unPoke :: Poke# s -- ^ the underlying unboxed write
  }

-- | Sequence two buffer writes left-to-right: the right-hand write begins at
--   the offset where the left-hand write finished, threading the state token
--   through both.
instance Semigroup (Poke s) where
    Poke lhs <> Poke rhs = Poke go
      where
        go base# off0# st0 = case lhs base# off0# st0 of
          (# st1, off1# #) -> rhs base# off1# st1

-- | The empty write touches no memory. It hands back the state token and the
--   offset it was given, unchanged.
instance Monoid (Poke s) where
    mempty = Poke noop
      where
        noop _ off# st = (# st, off# #)

-- | Run a 'Poke' into a newly allocated 'BS.ByteString' of exactly @len@ bytes.
--
-- The write must fit within @len@ bytes; this is not checked. The offset at
-- which the write finishes is discarded, so the result always has length
-- @len@.
unsafeRunPokeBS :: Int -> Poke RealWorld -> BS.ByteString
unsafeRunPokeBS len p = BS.unsafeCreate len (\buf -> void (unsafeRunPoke p buf))

-- | Run a 'Poke' into a newly allocated 'BS.ByteString' with room for at most
--   @len@ bytes. The resulting string's length is the offset at which the
--   write finished.
--
--   The buffer is not reallocated when fewer than @len@ bytes are written, and
--   writing more than @len@ bytes is not checked.
unsafeRunPokeBSUptoN :: Int -> Poke RealWorld -> BS.ByteString
unsafeRunPokeBSUptoN len p = BS.unsafeCreateUptoN len (unsafeRunPoke p)

-- | Run a 'Poke' against a raw pointer, starting at offset 0. Returns the
-- offset at which the write finished, i.e. the number of bytes written.
--
-- The pointer must refer to a mutable buffer with enough room for the whole
-- write. Nothing here is checked, so use with caution. Sensible uses:
--
-- * implementing pokes to ByteStrings and similar buffers
-- * running known-length (!!) pokes into known-length (!!) buffers, e.g.
--   together with allocaBytes
unsafeRunPoke :: MonadPrim s m => Poke s -> Ptr Word8 -> m Int
unsafeRunPoke (Poke p) (Ptr base#) = primitive run
  where
    run st0 = case p base# 0# st0 of
      (# st1, end# #) -> (# st1, I# end# #)

-- | Write a value at the current offset using its 'Prim'' instance, then
--   advance the offset by the type's size.
--
--   The size is taken from 'sizeOf#' applied to an 'undefined' of type @a@,
--   which relies on 'sizeOf#' not inspecting its argument.
prim :: forall a s. Prim' a => a -> Poke s
prim a = Poke go
  where
    go base# off# st0 = case writeWord8OffAddrAs# base# off# a st0 of
      st1 -> (# st1, off# +# sizeOf# (undefined :: a) #)

-- | Copy a 'BS.ByteString' into the buffer at the current offset, advancing
--   the offset by the string's length.
--
--   Uses a non-overlapping copy, so it assumes the string's bytes and the
--   destination range do not overlap.
byteString :: BS.ByteString -> Poke RealWorld
byteString (BS.BS (ForeignPtr src# contents) (I# len#)) = Poke go
  where
    -- withForeignPtr is reimplemented by hand via keepAlive#, because
    -- withForeignPtr is too high level to use here. keepAlive# has the wrong
    -- type before GHC 9.10, but that doesn't matter here:
    -- copyAddrToAddrNonOverlapping# forces RealWorld regardless.
    go base# off# st0 = keepAlive# contents st0 $ \st1 ->
      case MemcpyPrimops.copyAddrToAddrNonOverlapping#
             src# (base# `plusAddr#` off#) len# st1 of
        st2 -> (# st2, off# +# len# #)

-- | Copy a slice of a 'ByteArray#' into the buffer at the current offset,
--   advancing the offset by the slice length.
--
--   Neither the source slice nor the destination range is bounds-checked.
byteArray#
  :: ByteArray# -- ^ source array
  -> Int#       -- ^ byte offset into the source array to copy from
  -> Int#       -- ^ number of bytes to copy
  -> Poke s
byteArray# ba# baos# balen# = Poke go
  where
    go base# off# st0 =
      case copyByteArrayToAddr# ba# baos# (base# `plusAddr#` off#) balen# st0 of
        st1 -> (# st1, off# +# balen# #)

-- | Fill a run of bytes at the current offset with one repeated byte value
--   (essentially @memset@), advancing the offset by the run length.
replicateByte
  :: Int   -- ^ number of bytes to fill
  -> Word8 -- ^ byte value to fill them with
  -> Poke RealWorld
replicateByte (I# len#) (W8# w8#) = Poke go
  where
    go base# off# st0 =
      case MemcpyPrimops.setAddrRange#
             (base# `plusAddr#` off#) len# (word2Int# (word8ToWord# w8#)) st0 of
        st1 -> (# st1, off# +# len# #)

-- | Use a struct poke as a regular poke.
--
-- A struct poke does not report how far it wrote, so the caller must supply
-- its constant byte length. The next offset is the input offset plus that
-- length, and the length is trusted, not checked against what the struct
-- poke actually writes.
--
-- Struct pokers don't expose the type of the data they serialize, so this is
-- clumsy on its own. Use it only when such types are in scope, and obtain the
-- constant length in a sensible manner (e.g.
-- 'Bytezap.Struct.Generic.KnownSizeOf' for generic struct pokers, or your own
-- constant-size class if you're doing funky stuff).
fromStructPoke :: Int -> Struct.Poke s -> Poke s
fromStructPoke (I# len#) (Struct.Poke write) = Poke go
  where
    go base# off# st = (# write base# off# st, off# +# len# #)

-- | Use a regular poke as a struct poke by throwing away the offset it
--   returns.
--
--   The wrapped write still runs at whatever offset it is given. Only the
--   next offset it computes is discarded, so any tracking of its length has
--   to happen elsewhere.
toStructPoke :: Poke s -> Struct.Poke s
toStructPoke (Poke p) = Struct.Poke $ \base# os0# s0 ->
    case p base# os0# s0 of
      -- the struct form returns only the state token, so drop the offset
      (# s1, _os1# #) -> s1