module Basement.Block.Base
    ( Block(..)
    , MutableBlock(..)
    
    , unsafeNew
    , unsafeThaw
    , unsafeFreeze
    , unsafeShrink
    , unsafeCopyElements
    , unsafeCopyElementsRO
    , unsafeCopyBytes
    , unsafeCopyBytesRO
    , unsafeCopyBytesPtr
    , unsafeRead
    , unsafeWrite
    , unsafeIndex
    
    , length
    , lengthBytes
    , isPinned
    , isMutablePinned
    , mutableLength
    , mutableLengthBytes
    
    , mutableEmpty
    , new
    , newPinned
    , withPtr
    , withMutablePtr
    , withMutablePtrHint
    , mutableWithPtr
    , unsafeRecast
    ) where
import           GHC.Prim
import           GHC.Types
import           GHC.ST
import           GHC.IO
import qualified Data.List
import           Basement.Compat.Base
import           Data.Proxy
import           Basement.Compat.Primitive
import           Basement.Compat.Semigroup
import           Basement.Bindings.Memory (sysHsMemcmpBaBa)
import           Basement.Types.OffsetSize
import           Basement.Monad
import           Basement.NormalForm
import           Basement.Numerical.Additive
import           Basement.PrimType
-- | An immutable block of @ty@ values stored in a single GHC 'ByteArray#'.
-- The element type @ty@ does not appear in the representation; it only
-- tells the accessors how to interpret the bytes.
data Block ty = Block ByteArray#
    deriving Typeable
-- | Opaque 'Data' instance: the type can be named via 'dataTypeOf', but
-- constructor introspection and generic construction both raise errors.
instance Data ty => Data (Block ty) where
    toConstr _   = error "toConstr"
    gunfold _ _  = error "gunfold"
    dataTypeOf _ = blockType
-- | Non-representable 'DataType' descriptor shared by every @Block ty@.
blockType :: DataType
blockType = mkNoRepType typeName
  where
    typeName = "Foundation.Block"
-- | A block is already fully evaluated once its constructor is: the
-- payload is an unlifted 'ByteArray#', so forcing the constructor suffices.
instance NormalForm (Block ty) where
    toNormalForm blk = case blk of
        Block _ -> ()
-- | Renders a block exactly like the list of its elements.
instance (PrimType ty, Show ty) => Show (Block ty) where
    show blk = let elems = toList blk in show elems
-- | Element-wise equality, delegated to 'equal'.
instance (PrimType ty, Eq ty) => Eq (Block ty) where
    x == y = equal x y
-- | Lexicographic ordering on elements, delegated to 'internalCompare'.
instance (PrimType ty, Ord ty) => Ord (Block ty) where
    compare x y = internalCompare x y
-- | Concatenation of two blocks via 'append'.
instance PrimType ty => Semigroup (Block ty) where
    x <> y = append x y
-- | Monoid under concatenation, with the zero-length block as identity.
instance PrimType ty => Monoid (Block ty) where
    mempty      = empty
    mconcat bs  = concat bs
    -- NOTE(review): on base >= 4.11 the canonical form is @mappend = (<>)@;
    -- this is equivalent because the Semigroup instance also uses 'append'.
    mappend x y = append x y
-- | Conversion to and from Haskell lists of elements.
instance PrimType ty => IsList (Block ty) where
    type Item (Block ty) = ty
    toList blk  = internalToList blk
    fromList xs = internalFromList xs
-- | A mutable block of @ty@ values living in state thread @st@, backed by
-- a 'MutableByteArray#'. As with 'Block', @ty@ is not part of the payload.
data MutableBlock ty st =
    MutableBlock (MutableByteArray# st)
-- | Report whether the block's underlying byte array is pinned.
isPinned :: Block ty -> PinnedStatus
isPinned blk = case blk of
    Block raw -> toPinnedStatus# (compatIsByteArrayPinned# raw)
-- | Report whether a mutable block's underlying byte array is pinned.
--
-- The type variables follow the declaration @MutableBlock ty st@ (element
-- type first, state thread second), consistent with every other signature
-- in this module. The signature is still fully polymorphic, so all existing
-- call sites keep type-checking.
isMutablePinned :: MutableBlock ty st -> PinnedStatus
isMutablePinned (MutableBlock mba) = toPinnedStatus# (compatIsMutableByteArrayPinned# mba)
-- | Number of @ty@ elements in the block.
--
-- This is the byte size shifted right by 'primShiftToBytes' for @ty@. A
-- shift of 0 means the byte size is used directly.
length :: forall ty . PrimType ty => Block ty -> CountOf ty
length (Block ba) = CountOf (byteCount `shiftDown` primShiftToBytes (Proxy :: Proxy ty))
  where
    byteCount = I# (sizeofByteArray# ba)
    shiftDown n        0         = n
    shiftDown (I# n#) (I# shift#) = I# (uncheckedIShiftRL# n# shift#)
-- | Size of the block in bytes, independent of the element type.
lengthBytes :: Block ty -> CountOf Word8
lengthBytes blk = case blk of
    Block raw -> CountOf (I# (sizeofByteArray# raw))
-- | Number of @ty@ elements in a mutable block, derived from its byte size
-- with 'sizeRecast'.
mutableLength :: forall ty st . PrimType ty => MutableBlock ty st -> CountOf ty
mutableLength blk = sizeRecast (mutableLengthBytes blk)
-- | Size of a mutable block in bytes.
mutableLengthBytes :: MutableBlock ty st -> CountOf Word8
mutableLengthBytes blk = case blk of
    MutableBlock raw -> CountOf (I# (sizeofMutableByteArray# raw))
-- | The zero-length block, usable at any element type.
--
-- It re-wraps the bytes of the monomorphic 'empty_' so that a single
-- zero-byte array is shared.
empty :: Block ty
empty = case empty_ of
    Block raw -> Block raw

-- | A single zero-byte, unpinned, frozen block, built once at type @()@.
empty_ :: Block ()
empty_ = runST $ do
    mb <- unsafeNew Unpinned (CountOf 0)
    unsafeFreeze mb
-- | Allocate a fresh, unpinned, zero-byte mutable block.
mutableEmpty :: PrimMonad prim => prim (MutableBlock ty (PrimState prim))
mutableEmpty = unsafeNew Unpinned (CountOf 0)
-- | Read the element at the given offset without any bounds check.
-- The caller must guarantee that the offset is within 'length'.
unsafeIndex :: forall ty . PrimType ty => Block ty -> Offset ty -> ty
unsafeIndex (Block raw) = primBaIndex raw
-- | Build a block from a list.
--
-- The list is walked twice: once by 'Data.List.length' to size an unpinned
-- block, then again to write each element at consecutive offsets.
internalFromList :: PrimType ty => [ty] -> Block ty
internalFromList xs = runST $ do
    dst <- new (CountOf (Data.List.length xs))
    let fill _  []     = pure ()
        fill !o (y:ys) = unsafeWrite dst o y >> fill (o+1) ys
    fill azero xs
    unsafeFreeze dst
-- | Lazily unpack a block into a list of its elements, in offset order.
internalToList :: forall ty . PrimType ty => Block ty -> [ty]
internalToList blk@(Block raw)
    | count == azero = []
    | otherwise      = unpackFrom azero
  where
    !count = length blk
    unpackFrom !ofs
        | ofs .==# count = []
        | otherwise      = primBaIndex raw ofs : unpackFrom (ofs+1)
-- | Element-wise equality.
--
-- Blocks whose byte sizes differ are unequal straight away. Otherwise the
-- elements are compared with the element type's '==', stopping at the first
-- mismatch.
equal :: (PrimType ty, Eq ty) => Block ty -> Block ty -> Bool
equal a b = sameByteSize && allEqualFrom azero
  where
    sameByteSize = lengthBytes a == lengthBytes b
    elemCount    = length a
    one          = Offset (I# 1#)
    allEqualFrom !ofs
        | ofs .==# elemCount = True
        | otherwise          = unsafeIndex a ofs == unsafeIndex b ofs && allEqualFrom (ofs + one)
-- | Byte-wise equality using the C memcmp binding 'sysHsMemcmpBaBa'.
--
-- Blocks of different byte length are unequal. Otherwise they are equal
-- when memcmp over the full length returns 0. Not referenced in the code
-- shown here; presumably selected for 'PrimMemoryComparable' types
-- elsewhere (e.g. via rewrite rules) — TODO confirm.
equalMemcmp :: PrimMemoryComparable ty => Block ty -> Block ty -> Bool
equalMemcmp b1@(Block x) b2@(Block y) =
    byteLen == lengthBytes b2 && memcmpResult == 0
  where
    byteLen      = lengthBytes b1
    memcmpResult = unsafeDupablePerformIO (sysHsMemcmpBaBa x 0 y 0 byteLen)
-- | Lexicographic comparison.
--
-- Elements are compared pairwise over the shorter length. The first pair
-- that is not '==' decides the result via 'compare'. If the shared prefix
-- is identical, the element counts decide.
internalCompare :: (Ord ty, PrimType ty) => Block ty -> Block ty -> Ordering
internalCompare a b = walk azero
  where
    !lenA  = length a
    !lenB  = length b
    !limit = sizeAsOffset (min lenA lenB)
    walk !ofs
        | ofs == limit = compare lenA lenB
        | otherwise    =
            let x = unsafeIndex a ofs
                y = unsafeIndex b ofs
            in if x == y then walk (ofs + Offset (I# 1#)) else compare x y
-- | Byte-wise lexicographic comparison using 'sysHsMemcmpBaBa'.
--
-- memcmp runs over the shorter byte length. A zero result falls back to
-- comparing byte lengths; otherwise the sign of the result decides.
compareMemcmp :: PrimMemoryComparable ty => Block ty -> Block ty -> Ordering
compareMemcmp b1@(Block x) b2@(Block y)
    | r == 0    = compare lenX lenY
    | r > 0     = GT
    | otherwise = LT
  where
    lenX = lengthBytes b1
    lenY = lengthBytes b2
    r    = unsafeDupablePerformIO (sysHsMemcmpBaBa x 0 y 0 (min lenX lenY))
-- | Concatenate two blocks.
--
-- If either side is zero-length, the other block is returned unchanged.
-- Otherwise a new unpinned block is allocated and both byte ranges are
-- copied into it, left then right.
append :: Block ty -> Block ty -> Block ty
append left right
    | leftBytes  == azero = right
    | rightBytes == azero = left
    | otherwise           = runST $ do
        dst <- unsafeNew Unpinned (leftBytes + rightBytes)
        let split = sizeAsOffset leftBytes
        unsafeCopyBytesRO dst 0     left  0 leftBytes
        unsafeCopyBytesRO dst split right 0 rightBytes
        unsafeFreeze dst
  where
    !leftBytes  = lengthBytes left
    !rightBytes = lengthBytes right
-- | Concatenate a list of blocks, in order.
--
-- Zero-length blocks are skipped first. Then:
--
-- * no non-empty blocks left: the shared 'empty' block is returned;
-- * exactly one non-empty block: it is returned as-is, with no copy. This
--   mirrors the zero-length short-circuits in 'append'. The result may then
--   share storage (and pinned status) with the input;
-- * otherwise: a new unpinned block of the total byte size is allocated,
--   and each chunk is copied at its running byte offset.
concat :: forall ty . [Block ty] -> Block ty
concat original =
    case chunks of
        []      -> empty
        [chunk] -> chunk
        _       -> runST $ do
            r <- unsafeNew Unpinned total
            goCopy r 0 chunks
            unsafeFreeze r
  where
    -- Dropping empty chunks lets the singleton fast path fire for inputs
    -- such as [empty, x, empty], and avoids useless zero-length copies.
    chunks = Data.List.filter (\b -> lengthBytes b /= azero) original

    -- Total payload in bytes; the accumulator is strict so that no chain
    -- of additions builds up.
    total = size 0 chunks
    size !sz []     = sz
    size !sz (x:xs) = size (lengthBytes x + sz) xs

    -- Copy every chunk into r, advancing the destination byte offset.
    goCopy r = loop
      where
        loop _  []     = pure ()
        loop !i (x:xs) = do
            unsafeCopyBytesRO r i x 0 lx
            loop (i `offsetPlusE` lx) xs
          where !lx = lengthBytes x
-- | Reinterpret a mutable block as immutable without copying.
-- The caller must not mutate the source afterwards.
unsafeFreeze :: PrimMonad prim => MutableBlock ty (PrimState prim) -> prim (Block ty)
unsafeFreeze mblk = case mblk of
    MutableBlock raw -> primitive $ \st ->
        case unsafeFreezeByteArray# raw st of
            (# st', frozen #) -> (# st', Block frozen #)
-- | Shrink a mutable block via 'compatShrinkMutableByteArray#' and return
-- the resulting block. The caller should continue with the returned block.
--
-- NOTE(review): the @CountOf ty@ value is passed to
-- 'compatShrinkMutableByteArray#' without scaling by the element size. If
-- that primitive expects a byte size (as GHC's @shrinkMutableByteArray#@
-- does), this is only correct for one-byte element types — confirm with
-- callers.
unsafeShrink :: PrimMonad prim => MutableBlock ty (PrimState prim) -> CountOf ty -> prim (MutableBlock ty (PrimState prim))
unsafeShrink (MutableBlock raw) (CountOf (I# newSize)) = primitive $ \s1 ->
    case compatShrinkMutableByteArray# raw newSize s1 of
        (# s2, shrunk #) -> (# s2, MutableBlock shrunk #)
-- | Reinterpret an immutable block as mutable without copying, by coercing
-- the 'ByteArray#'. Unsafe: writes become visible through the original block.
unsafeThaw :: (PrimType ty, PrimMonad prim) => Block ty -> prim (MutableBlock ty (PrimState prim))
unsafeThaw blk = case blk of
    Block raw -> primitive (\st -> (# st, MutableBlock (unsafeCoerce# raw) #))
-- | Allocate an uninitialised mutable block of the given byte size.
--
-- 'Unpinned' uses 'newByteArray#'. Any other status allocates a pinned
-- array aligned to 8 bytes with 'newAlignedPinnedByteArray#'.
unsafeNew :: PrimMonad prim
          => PinnedStatus
          -> CountOf Word8
          -> prim (MutableBlock ty (PrimState prim))
unsafeNew pinSt (CountOf (I# bytes))
    | Unpinned <- pinSt = primitive $ \s1 ->
        case newByteArray# bytes s1 of
            (# s2, mba #) -> (# s2, MutableBlock mba #)
    | otherwise         = primitive $ \s1 ->
        case newAlignedPinnedByteArray# bytes 8# s1 of
            (# s2, mba #) -> (# s2, MutableBlock mba #)
-- | Allocate an uninitialised, unpinned mutable block holding @n@ elements.
new :: forall prim ty . (PrimMonad prim, PrimType ty) => CountOf ty -> prim (MutableBlock ty (PrimState prim))
new n = unsafeNew Unpinned byteSize
  where
    byteSize = sizeOfE (primSizeInBytes (Proxy :: Proxy ty)) n
-- | Allocate an uninitialised, pinned mutable block holding @n@ elements.
newPinned :: forall prim ty . (PrimMonad prim, PrimType ty) => CountOf ty -> prim (MutableBlock ty (PrimState prim))
newPinned n = unsafeNew Pinned byteSize
  where
    byteSize = sizeOfE (primSizeInBytes (Proxy :: Proxy ty)) n
-- | Copy @n@ elements between mutable blocks, with no bounds checks.
-- Element offsets and the count are converted to bytes and passed to
-- 'unsafeCopyBytes'.
unsafeCopyElements :: forall prim ty . (PrimMonad prim, PrimType ty)
                   => MutableBlock ty (PrimState prim) -- ^ destination
                   -> Offset ty                        -- ^ destination offset (elements)
                   -> MutableBlock ty (PrimState prim) -- ^ source
                   -> Offset ty                        -- ^ source offset (elements)
                   -> CountOf ty                       -- ^ number of elements
                   -> prim ()
unsafeCopyElements dstMb destOffset srcMb srcOffset n =
    unsafeCopyBytes dstMb (toByteOfs destOffset) srcMb (toByteOfs srcOffset) byteCount
  where
    !elemBytes = primSizeInBytes (Proxy :: Proxy ty)
    toByteOfs  = offsetOfE elemBytes
    byteCount  = sizeOfE elemBytes n
-- | Copy @n@ elements from an immutable block into a mutable one, with no
-- bounds checks. Offsets and the count are converted to bytes and passed to
-- 'unsafeCopyBytesRO'.
unsafeCopyElementsRO :: forall prim ty . (PrimMonad prim, PrimType ty)
                     => MutableBlock ty (PrimState prim) -- ^ destination
                     -> Offset ty                        -- ^ destination offset (elements)
                     -> Block ty                         -- ^ source
                     -> Offset ty                        -- ^ source offset (elements)
                     -> CountOf ty                       -- ^ number of elements
                     -> prim ()
unsafeCopyElementsRO dstMb destOffset srcMb srcOffset n =
    unsafeCopyBytesRO dstMb (toByteOfs destOffset) srcMb (toByteOfs srcOffset) byteCount
  where
    !elemBytes = primSizeInBytes (Proxy :: Proxy ty)
    toByteOfs  = offsetOfE elemBytes
    byteCount  = sizeOfE elemBytes n
-- | Copy a byte range between mutable blocks using
-- 'copyMutableByteArray#', with no bounds checks.
unsafeCopyBytes :: forall prim ty . PrimMonad prim
                => MutableBlock ty (PrimState prim) -- ^ destination
                -> Offset Word8                     -- ^ destination byte offset
                -> MutableBlock ty (PrimState prim) -- ^ source
                -> Offset Word8                     -- ^ source byte offset
                -> CountOf Word8                    -- ^ number of bytes
                -> prim ()
unsafeCopyBytes (MutableBlock dst) (Offset (I# dstOfs)) (MutableBlock src) (Offset (I# srcOfs)) (CountOf (I# count)) =
    primitive $ \s1 ->
        case copyMutableByteArray# src srcOfs dst dstOfs count s1 of
            s2 -> (# s2, () #)
-- | Copy a byte range from an immutable block into a mutable one using
-- 'copyByteArray#', with no bounds checks.
unsafeCopyBytesRO :: forall prim ty . PrimMonad prim
                  => MutableBlock ty (PrimState prim) -- ^ destination
                  -> Offset Word8                     -- ^ destination byte offset
                  -> Block ty                         -- ^ source
                  -> Offset Word8                     -- ^ source byte offset
                  -> CountOf Word8                    -- ^ number of bytes
                  -> prim ()
unsafeCopyBytesRO (MutableBlock dst) (Offset (I# dstOfs)) (Block src) (Offset (I# srcOfs)) (CountOf (I# count)) =
    primitive $ \s1 ->
        case copyByteArray# src srcOfs dst dstOfs count s1 of
            s2 -> (# s2, () #)
-- | Copy bytes from a raw pointer into a mutable block using
-- 'copyAddrToByteArray#', with no bounds or validity checks.
unsafeCopyBytesPtr :: forall prim ty . PrimMonad prim
                   => MutableBlock ty (PrimState prim) -- ^ destination
                   -> Offset Word8                     -- ^ destination byte offset
                   -> Ptr ty                           -- ^ source address
                   -> CountOf Word8                    -- ^ number of bytes
                   -> prim ()
unsafeCopyBytesPtr (MutableBlock dst) (Offset (I# dstOfs)) (Ptr srcAddr) (CountOf (I# count)) =
    primitive $ \s1 ->
        case copyAddrToByteArray# srcAddr dst dstOfs count s1 of
            s2 -> (# s2, () #)
-- | Read the element at the given offset without a bounds check.
unsafeRead :: (PrimMonad prim, PrimType ty) => MutableBlock ty (PrimState prim) -> Offset ty -> prim ty
unsafeRead (MutableBlock raw) = primMbaRead raw
-- | Write an element at the given offset without a bounds check.
unsafeWrite :: (PrimMonad prim, PrimType ty) => MutableBlock ty (PrimState prim) -> Offset ty -> ty -> prim ()
unsafeWrite (MutableBlock raw) = primMbaWrite raw
-- | Run an action with a pointer to the block's bytes.
--
-- A pinned block's own contents are exposed directly. An unpinned block is
-- first copied into a freshly allocated pinned block, and the pointer refers
-- to that copy. Nothing is copied back, so writes through the pointer do not
-- reach the original in that case. The exposed block is touched after the
-- action so it stays alive while the pointer is in use.
withPtr :: PrimMonad prim
        => Block ty
        -> (Ptr ty -> prim a)
        -> prim a
withPtr x f
    | isPinned x == Pinned = exposing x
    | otherwise            = pinnedCopy >>= exposing
  where
    exposing blk@(Block raw) = f (Ptr (byteArrayContents# raw)) <* touch blk
    pinnedCopy = do
        tmp <- unsafeNew Pinned (lengthBytes x)
        unsafeCopyBytesRO tmp 0 x 0 (lengthBytes x)
        unsafeFreeze tmp
-- | Apply 'touch#' to the block's byte array so the garbage collector
-- treats it as live up to this point. The action is built in 'IO' and
-- lifted with 'unsafePrimFromIO'.
touch :: PrimMonad prim => Block ty -> prim ()
touch (Block raw) = unsafePrimFromIO (primitive keepAlive)
  where
    keepAlive s1 = case touch# raw s1 of
        s2 -> (# s2, () #)
-- | Reinterpret a mutable block's element type without touching its bytes.
-- The byte size is unchanged, so the element count may change.
unsafeRecast :: (PrimType t1, PrimType t2)
             => MutableBlock t1 st
             -> MutableBlock t2 st
unsafeRecast mblk = case mblk of
    MutableBlock raw -> MutableBlock raw
-- | Alias of 'withMutablePtr'.
mutableWithPtr :: PrimMonad prim
                => MutableBlock ty (PrimState prim)
                -> (Ptr ty -> prim a)
                -> prim a
mutableWithPtr mb f = withMutablePtr mb f
-- | Run an action with a pointer to a mutable block's bytes.
-- This is 'withMutablePtrHint' with both copy-skipping hints off: for an
-- unpinned block, contents are copied into the temporary pinned buffer
-- before the action and copied back afterwards.
withMutablePtr :: PrimMonad prim
               => MutableBlock ty (PrimState prim)
               -> (Ptr ty -> prim a)
               -> prim a
withMutablePtr mb f = withMutablePtrHint False False mb f
-- | Run an action with a pointer to a mutable block's bytes, with hints
-- for the unpinned case.
--
-- A pinned block is exposed directly. Otherwise a pinned block of the same
-- byte size is allocated. Unless @skipCopy@ is set, the contents are copied
-- into it before the action. Unless @skipCopyBack@ is set, its contents are
-- copied back over the original afterwards. The exposed block is touched
-- after the action.
withMutablePtrHint :: forall ty prim a . PrimMonad prim
                   => Bool -- ^ skip copying into the temporary buffer
                   -> Bool -- ^ skip copying the temporary buffer back
                   -> MutableBlock ty (PrimState prim)
                   -> (Ptr ty -> prim a)
                   -> prim a
withMutablePtrHint skipCopy skipCopyBack mb f =
    case isMutablePinned mb of
        Pinned -> exposing mb
        _      -> do
            buffer <- unsafeNew Pinned byteLen
            if skipCopy
                then pure ()
                else unsafeCopyBytes buffer 0 mb 0 byteLen
            result <- exposing buffer
            if skipCopyBack
                then pure ()
                else unsafeCopyBytes mb 0 buffer 0 byteLen
            pure result
  where
    byteLen = mutableLengthBytes mb
    -- Freezing here only recovers a ByteArray# for byteArrayContents#;
    -- the underlying memory is the same.
    exposing pinnedMb = do
        blk@(Block raw) <- unsafeFreeze pinnedMb
        f (Ptr (byteArrayContents# raw)) <* touch blk