module Basement.UArray.Mutable
    ( MUArray(..)
    
    , sizeInMutableBytesOfContent
    , mutableLength
    , mutableOffset
    , mutableSame
    , onMutableBackend
    
    , new
    , newPinned
    , newNative
    , mutableForeignMem
    , copyAt
    , copyFromPtr
    , copyToPtr
    , sub
    
    
    , unsafeWrite
    , unsafeRead
    , write
    , read
    , withMutablePtr
    ) where
import           GHC.Prim
import           GHC.Types
import           GHC.Ptr
import           Basement.Compat.Base
import           Basement.Compat.Primitive
import           Data.Proxy
import           Basement.Types.OffsetSize
import           Basement.Monad
import           Basement.PrimType
import           Basement.FinalPtr
import           Basement.Exception
import qualified Basement.Block         as BLK
import qualified Basement.Block.Mutable as MBLK
import           Basement.Block         (MutableBlock(..))
import           Basement.UArray.Base hiding (empty)
import           Basement.Numerical.Subtractive
import           Foreign.Marshal.Utils (copyBytes)
-- | Size in bytes of a single element of the array's element type.
--
-- The array argument is never inspected; it only fixes the element type
-- @ty@. Despite the name, this is the per-element size, not the byte size
-- of the whole content.
sizeInMutableBytesOfContent :: forall ty s . PrimType ty => MUArray ty s -> CountOf Word8
sizeInMutableBytesOfContent _ = elemBytes
  where
    elemBytes = primSizeInBytes (Proxy :: Proxy ty)
-- | Bounds-checked read of the element at offset @n@.
--
-- When 'isOutOfBound' rejects @n@ against the array's 'mutableLength', the
-- failure is reported through 'primOutOfBound' with 'OOB_Read'. Otherwise
-- the read is delegated to 'unsafeRead'.
read :: (PrimMonad prim, PrimType ty) => MUArray ty (PrimState prim) -> Offset ty -> prim ty
read array n =
    if isOutOfBound n count
        then primOutOfBound OOB_Read n count
        else unsafeRead array n
  where
    count = mutableLength array
-- | Bounds-checked write of @val@ at offset @n@.
--
-- When 'isOutOfBound' rejects @n@ against the array's 'mutableLength', the
-- failure is reported through 'primOutOfBound' with 'OOB_Write'. Otherwise
-- the write is delegated to 'unsafeWrite'.
write :: (PrimMonad prim, PrimType ty) => MUArray ty (PrimState prim) -> Offset ty -> ty -> prim ()
write array n val =
    if isOutOfBound n count
        then primOutOfBound OOB_Write n count
        else unsafeWrite array n val
  where
    count = mutableLength array
-- | A zero-length mutable array backed by 'MBLK.mutableEmpty'.
--
-- Start offset and length are both 0.
empty :: (PrimType ty, PrimMonad prim) => prim (MUArray ty (PrimState prim))
empty = do
    blk <- MBLK.mutableEmpty
    pure (MUArray 0 0 (MUArrayMBA blk))
-- | Whether two mutable arrays are the same view of the same memory.
--
-- Both arrays must use the same kind of backend and have equal start offset
-- and length. Native blocks must also be the same byte array
-- ('sameMutableByteArray#'). Foreign memory must also satisfy
-- 'finalPtrSameMemory'. Arrays with different backend kinds are never
-- considered the same.
mutableSame :: MUArray ty st -> MUArray ty st -> Bool
mutableSame (MUArray startA lenA backA) (MUArray startB lenB backB) =
    case (backA, backB) of
        (MUArrayMBA (MutableBlock x), MUArrayMBA (MutableBlock y)) ->
            sameWindow && bool# (sameMutableByteArray# x y)
        (MUArrayAddr x, MUArrayAddr y) ->
            sameWindow && finalPtrSameMemory x y
        _ -> False
  where
    sameWindow = (startA == startB) && (lenA == lenB)
-- | View foreign memory held by a 'FinalPtr' as a mutable array.
--
-- The view starts at offset 0. The 'Int' argument becomes the array's
-- 'CountOf ty' length, so it is a number of elements, not bytes. No memory is
-- copied or allocated.
mutableForeignMem :: (PrimMonad prim, PrimType ty)
                  => FinalPtr ty -- ^ foreign memory to wrap
                  -> Int         -- ^ length of the view, in elements
                  -> prim (MUArray ty (PrimState prim))
mutableForeignMem fptr nb = pure (MUArray start len (MUArrayAddr fptr))
  where
    start = Offset 0
    len   = CountOf nb
-- | A view of a slice of a mutable array that shares the same backend.
--
-- The first element-count argument says how many leading elements to drop;
-- negative values are clamped to 0. The second gives the maximum number of
-- elements to keep after dropping. A fresh 'empty' array is returned in two
-- cases: the keep count is not positive, or no elements remain after the drop.
sub :: (PrimMonad prim, PrimType ty)
    => MUArray ty (PrimState prim)
    -> Int -- ^ number of elements to drop from the front
    -> Int -- ^ maximum number of elements to keep
    -> prim (MUArray ty (PrimState prim))
sub (MUArray start sz back) dropElems' takeElems
    | takeElems <= 0 = empty
    -- 'CountOf' subtraction (Basement.Numerical.Subtractive) yields a 'Maybe'.
    -- A non-'Just' or non-positive remainder falls through to 'empty'.
    | Just keepElems <- sz - dropElems, keepElems > 0
                     = pure $ MUArray (start `offsetPlusE` dropElems) (min (CountOf takeElems) keepElems) back
    | otherwise      = empty
  where
    dropElems = max 0 (CountOf dropElems')
-- | Number of elements visible through this mutable array view.
mutableLength :: PrimType ty => MUArray ty st -> CountOf ty
mutableLength arr = case arr of
    MUArray _ n _ -> n
-- | Run an action with a raw pointer to the first element of the array's view.
--
-- The pointer is the backing memory's base address plus the view's start
-- offset converted to bytes.
--
-- * Foreign ('MUArrayAddr') storage and pinned native blocks: the action gets
--   a pointer into the array's own memory.
-- * Unpinned native blocks: the contents are staged through a temporary
--   allocated with 'newPinned'.
--   * When the first 'Bool' is 'True', the array is not copied into the
--     temporary before the action runs.
--   * When the second 'Bool' is 'True', the temporary is not copied back into
--     the array after the action returns.
--
-- NOTE(review): if the action throws, the copy-back step is never reached.
withMutablePtrHint :: forall ty prim a . (PrimMonad prim, PrimType ty)
                   => Bool
                   -> Bool
                   -> MUArray ty (PrimState prim)
                   -> (Ptr ty -> prim a)
                   -> prim a
withMutablePtrHint _ _ (MUArray start _ (MUArrayAddr fptr)) f =
    withFinalPtr fptr $ \base -> f (base `plusPtr` byteOfs)
  where
    !(Offset byteOfs) = offsetOfE (primSizeInBytes (Proxy :: Proxy ty)) start
withMutablePtrHint skipCopy skipCopyBack vec@(MUArray start vecSz (MUArrayMBA mb)) f
    | BLK.isMutablePinned mb == Pinned = MBLK.mutableWithAddr mb $ \base -> f (base `plusPtr` byteOfs)
    | otherwise                        = viaTemporary
  where
    !(Offset byteOfs) = offsetOfE (primSizeInBytes (Proxy :: Proxy ty)) start
    -- Unpinned memory may move, so work on a pinned copy and sync it back.
    viaTemporary = do
        tmp <- newPinned vecSz
        if skipCopy then pure () else copyAt tmp 0 vec 0 vecSz
        result <- withMutablePtr tmp f
        if skipCopyBack then pure () else copyAt vec 0 tmp 0 vecSz
        pure result
-- | Run an action with a raw pointer to the first element of the array's view.
--
-- This is 'withMutablePtrHint' with both copy steps enabled. For unpinned
-- native storage, the data is copied into a pinned temporary before the
-- action and copied back afterwards.
withMutablePtr :: (PrimMonad prim, PrimType ty)
               => MUArray ty (PrimState prim)
               -> (Ptr ty -> prim a)
               -> prim a
withMutablePtr arr action = withMutablePtrHint False False arr action
-- | Copy @count@ elements from a raw pointer into the start of the array's view.
--
-- Fails via 'primOutOfBound' with 'OOB_MemCopy' when @count@ is larger than
-- 'mutableLength'.
--
-- NOTE(review): the source pointer is assumed to reference at least @count@
-- readable elements; this is not checked.
copyFromPtr :: forall prim ty . (PrimMonad prim, PrimType ty)
            => Ptr ty -> CountOf ty -> MUArray ty (PrimState prim) -> prim ()
copyFromPtr src@(Ptr src#) count marr =
    if count > capacity
        then primOutOfBound OOB_MemCopy (sizeAsOffset count) capacity
        else onMutableBackend intoBlock intoForeign marr
  where
    capacity = mutableLength marr
    elemSize = primSizeInBytes (Proxy :: Proxy ty)
    -- Byte count to copy and byte offset of the view inside its backing memory.
    !(CountOf nBytes@(I# nBytes#)) = sizeOfE elemSize count
    !(Offset dstOfs@(I# dstOfs#))  = offsetOfE elemSize (mutableOffset marr)
    intoBlock mba = primitive $ \st -> (# copyAddrToByteArray# src# mba dstOfs# nBytes# st, () #)
    intoForeign fptr = withFinalPtr fptr $ \dst ->
        unsafePrimFromIO $ copyBytes (dst `plusPtr` dstOfs) src nBytes
-- | Copy every element of the array's view ('mutableLength' of them) to a raw
-- pointer.
--
-- NOTE(review): the destination is assumed to have room for the whole view;
-- this is not checked.
copyToPtr :: forall ty prim . (PrimType ty, PrimMonad prim)
          => MUArray ty (PrimState prim) -- ^ source array
          -> Ptr ty                      -- ^ destination pointer
          -> prim ()
copyToPtr marr dst@(Ptr dst#) = onMutableBackend fromBlock fromForeign marr
  where
    -- Byte offset of the view in its backing memory and byte length of the view.
    !(Offset srcOfs@(I# srcOfs#))  = offsetInBytes (mutableOffset marr)
    !(CountOf nBytes@(I# nBytes#)) = sizeInBytes (mutableLength marr)
    -- The block is frozen in place (no copy) only to get the immutable handle
    -- that the copy primop takes; the array itself is not modified.
    fromBlock mba = primitive $ \s1 ->
        case unsafeFreezeByteArray# mba s1 of
            (# s2, frozen #) -> (# compatCopyByteArrayToAddr# frozen srcOfs# dst# nBytes# s2, () #)
    fromForeign fptr = unsafePrimFromIO $ withFinalPtr fptr $ \base ->
        copyBytes dst (base `plusPtr` srcOfs) nBytes
-- | Start offset, in elements, of this view inside its backing storage.
mutableOffset :: MUArray ty st -> Offset ty
mutableOffset arr = case arr of
    MUArray o _ _ -> o