{-# language BangPatterns #-}
{-# language MagicHash #-}
{-# language RankNTypes #-}
{-# language ScopedTypeVariables #-}
{-# language TypeFamilies #-}
{-# language UnboxedTuples #-}
module Data.Primitive.Unlifted.Array
  ( 
    UnliftedArray(..)
  , MutableUnliftedArray(..)
    
  , newUnliftedArray
  , unsafeNewUnliftedArray
  , sizeofUnliftedArray
  , sizeofMutableUnliftedArray
  , sameMutableUnliftedArray
  , writeUnliftedArray
  , readUnliftedArray
  , indexUnliftedArray
  , unsafeFreezeUnliftedArray
  , freezeUnliftedArray
  , thawUnliftedArray
  , setUnliftedArray
  , copyUnliftedArray
  , copyMutableUnliftedArray
  , cloneUnliftedArray
  , cloneMutableUnliftedArray
  , emptyUnliftedArray
  , runUnliftedArray
    
  , unliftedArrayToList
  , unliftedArrayFromList
  , unliftedArrayFromListN
    
  , foldrUnliftedArray
  , foldrUnliftedArray'
  , foldlUnliftedArray
  , foldlUnliftedArray'
  , foldlUnliftedArrayM'
    
  , traverseUnliftedArray_
  , itraverseUnliftedArray_
    
  , mapUnliftedArray
  ) where
import Control.Monad.Primitive (PrimMonad,PrimState,primitive,primitive_)
import Control.Monad.ST (ST)
import Data.Primitive.Unlifted.Class (PrimUnlifted)
import GHC.Exts (Int(I#),MutableArrayArray#,ArrayArray#,State#)
import qualified Data.List as L
import qualified Data.Primitive.Unlifted.Class as C
import qualified GHC.Exts as Exts
import qualified GHC.ST as ST
-- | A mutable array of unlifted values: a boxed wrapper around a raw
-- 'MutableArrayArray#'. The @s@ parameter is the state-token type of the
-- wrapped primitive array; the element type @a@ does not appear in the
-- field and is therefore phantom in this declaration.
data MutableUnliftedArray s a = MutableUnliftedArray (MutableArrayArray# s)
-- | An immutable array of unlifted values: a boxed wrapper around a raw
-- 'ArrayArray#'. The element type @a@ does not appear in the field and is
-- therefore phantom in this declaration.
data UnliftedArray a = UnliftedArray ArrayArray#
-- | Allocate a mutable unlifted array with the given number of slots using
-- 'Exts.newArrayArray#'.
--
-- No element value is supplied and nothing is written into the slots here,
-- so callers are expected to fill every slot before reading it.
-- NOTE(review): confirm the initial slot contents against the GHC
-- documentation for @newArrayArray#@.
unsafeNewUnliftedArray
  :: PrimMonad m
  => Int -- ^ number of slots
  -> m (MutableUnliftedArray (PrimState m) a)
{-# inline unsafeNewUnliftedArray #-}
unsafeNewUnliftedArray (I# len#) = primitive allocate
  where
    -- Thread the state token through the primop and box the result.
    allocate s0 = case Exts.newArrayArray# len# s0 of
      (# s1, marr# #) -> (# s1, MutableUnliftedArray marr# #)
-- | Allocate a mutable unlifted array of the given size and set every slot
-- to the given value.
newUnliftedArray
  :: (PrimMonad m, PrimUnlifted a)
  => Int -- ^ number of slots
  -> a -- ^ value written to every slot
  -> m (MutableUnliftedArray (PrimState m) a)
{-# inline newUnliftedArray #-}
newUnliftedArray len v =
  unsafeNewUnliftedArray len >>= \marr ->
    setUnliftedArray marr v 0 len >> pure marr
-- | Write the same value into a contiguous range of slots of a mutable
-- array.
--
-- The slots are written from the highest index of the range down to the
-- offset. No bounds checking is performed in this function.
setUnliftedArray
  :: (PrimMonad m, PrimUnlifted a)
  => MutableUnliftedArray (PrimState m) a -- ^ destination
  -> a -- ^ value to write
  -> Int -- ^ offset of the first slot to write
  -> Int -- ^ number of slots to write
  -> m ()
{-# inline setUnliftedArray #-}
setUnliftedArray mua v off len = fillDown (len + off - 1)
  where
    fillDown ix
      | ix >= off = writeUnliftedArray mua ix v *> fillDown (ix - 1)
      | otherwise = pure ()
-- | The size of an immutable unlifted array, as reported by
-- 'Exts.sizeofArrayArray#'.
sizeofUnliftedArray :: UnliftedArray e -> Int
{-# inline sizeofUnliftedArray #-}
sizeofUnliftedArray arr = case arr of
  UnliftedArray arr# -> I# (Exts.sizeofArrayArray# arr#)
-- | The size of a mutable unlifted array, as reported by
-- 'Exts.sizeofMutableArrayArray#'.
sizeofMutableUnliftedArray :: MutableUnliftedArray s e -> Int
{-# inline sizeofMutableUnliftedArray #-}
sizeofMutableUnliftedArray marr = case marr of
  MutableUnliftedArray marr# -> I# (Exts.sizeofMutableArrayArray# marr#)
-- | Store a value at the given index of a mutable unlifted array.
--
-- The write is delegated to the 'PrimUnlifted' class method
-- 'C.writeUnliftedArray#'; no bounds check is performed in this function.
writeUnliftedArray
  :: (PrimMonad m, PrimUnlifted a)
  => MutableUnliftedArray (PrimState m) a -- ^ destination
  -> Int -- ^ index
  -> a -- ^ value to store
  -> m ()
{-# inline writeUnliftedArray #-}
writeUnliftedArray (MutableUnliftedArray marr#) (I# ix#) x =
  primitive_ $ \s -> C.writeUnliftedArray# marr# ix# x s
-- | Fetch the value at the given index of a mutable unlifted array.
--
-- The read is delegated to the 'PrimUnlifted' class method
-- 'C.readUnliftedArray#'; no bounds check is performed in this function.
readUnliftedArray
  :: (PrimMonad m, PrimUnlifted a)
  => MutableUnliftedArray (PrimState m) a -- ^ source
  -> Int -- ^ index
  -> m a
{-# inline readUnliftedArray #-}
readUnliftedArray (MutableUnliftedArray marr#) (I# ix#) =
  primitive $ \s -> C.readUnliftedArray# marr# ix# s
-- | Fetch the value at the given index of an immutable unlifted array.
--
-- The lookup is delegated to the 'PrimUnlifted' class method
-- 'C.indexUnliftedArray#'; no bounds check is performed in this function.
indexUnliftedArray
  :: PrimUnlifted a
  => UnliftedArray a -- ^ source
  -> Int -- ^ index
  -> a
{-# inline indexUnliftedArray #-}
indexUnliftedArray arr i =
  -- Unwrap the array first and the index second, mirroring the order in
  -- which the arguments are matched.
  case arr of
    UnliftedArray arr# -> case i of
      I# ix# -> C.indexUnliftedArray# arr# ix#
-- | Turn a mutable unlifted array into an immutable one using
-- 'Exts.unsafeFreezeArrayArray#'.
--
-- NOTE(review): as with other unsafe-freeze primops, the mutable array
-- should presumably not be written to afterwards; confirm against the GHC
-- primop documentation.
unsafeFreezeUnliftedArray
  :: PrimMonad m
  => MutableUnliftedArray (PrimState m) a
  -> m (UnliftedArray a)
{-# inline unsafeFreezeUnliftedArray #-}
unsafeFreezeUnliftedArray (MutableUnliftedArray marr#) = primitive freeze
  where
    -- Thread the state token through the primop and box the result.
    freeze s0 = case Exts.unsafeFreezeArrayArray# marr# s0 of
      (# s1, arr# #) -> (# s1, UnliftedArray arr# #)
-- | Test whether two mutable unlifted arrays are the same array, as decided
-- by 'Exts.sameMutableArrayArray#'. Element contents are not inspected.
sameMutableUnliftedArray
  :: MutableUnliftedArray s a
  -> MutableUnliftedArray s a
  -> Bool
{-# inline sameMutableUnliftedArray #-}
sameMutableUnliftedArray (MutableUnliftedArray lhs#) (MutableUnliftedArray rhs#) =
  case Exts.sameMutableArrayArray# lhs# rhs# of
    same# -> Exts.isTrue# same#
-- | Copy a slice of an immutable unlifted array into a mutable one using
-- 'Exts.copyArrayArray#'. No bounds checking is performed in this function.
copyUnliftedArray
  :: PrimMonad m
  => MutableUnliftedArray (PrimState m) a -- ^ destination
  -> Int -- ^ offset into the destination
  -> UnliftedArray a -- ^ source
  -> Int -- ^ offset into the source
  -> Int -- ^ number of elements to copy
  -> m ()
{-# inline copyUnliftedArray #-}
copyUnliftedArray
  (MutableUnliftedArray dst#) (I# dstOff#)
  (UnliftedArray src#) (I# srcOff#) (I# count#) =
    -- The primop takes the source first, the wrapper the destination first.
    primitive_ (\s -> Exts.copyArrayArray# src# srcOff# dst# dstOff# count# s)
-- | Copy a slice of one mutable unlifted array into another using
-- 'Exts.copyMutableArrayArray#'. No bounds checking is performed in this
-- function.
copyMutableUnliftedArray
  :: PrimMonad m
  => MutableUnliftedArray (PrimState m) a -- ^ destination
  -> Int -- ^ offset into the destination
  -> MutableUnliftedArray (PrimState m) a -- ^ source
  -> Int -- ^ offset into the source
  -> Int -- ^ number of elements to copy
  -> m ()
{-# inline copyMutableUnliftedArray #-}
copyMutableUnliftedArray
  (MutableUnliftedArray dst#) (I# dstOff#)
  (MutableUnliftedArray src#) (I# srcOff#) (I# count#) =
    -- The primop takes the source first, the wrapper the destination first.
    primitive_ (\s -> Exts.copyMutableArrayArray# src# srcOff# dst# dstOff# count# s)
-- | Produce an immutable copy of a slice of a mutable unlifted array.
--
-- A fresh array of the requested length is allocated, the slice is copied
-- into it, and the fresh array is frozen; the source is left untouched.
freezeUnliftedArray
  :: PrimMonad m
  => MutableUnliftedArray (PrimState m) a -- ^ source
  -> Int -- ^ offset into the source
  -> Int -- ^ number of elements to copy
  -> m (UnliftedArray a)
{-# inline freezeUnliftedArray #-}
freezeUnliftedArray src off len =
  unsafeNewUnliftedArray len >>= \dst ->
    copyMutableUnliftedArray dst 0 src off len
      >> unsafeFreezeUnliftedArray dst
-- | Produce a mutable copy of a slice of an immutable unlifted array.
--
-- A fresh array of the requested length is allocated and the slice is
-- copied into it.
thawUnliftedArray
  :: PrimMonad m
  => UnliftedArray a -- ^ source
  -> Int -- ^ offset into the source
  -> Int -- ^ number of elements to copy
  -> m (MutableUnliftedArray (PrimState m) a)
{-# inline thawUnliftedArray #-}
thawUnliftedArray src off len =
  unsafeNewUnliftedArray len >>= \dst ->
    copyUnliftedArray dst 0 src off len >> pure dst
-- | Build an immutable unlifted array by allocating a mutable array of the
-- given size, running the supplied initialiser on it, and freezing the
-- result via 'runUnliftedArray'.
--
-- The array comes from 'unsafeNewUnliftedArray', so the initialiser is
-- responsible for filling every slot. The size is forced before anything
-- else happens.
unsafeCreateUnliftedArray
  :: Int -- ^ number of slots
  -> (forall s. MutableUnliftedArray s a -> ST s ()) -- ^ initialiser
  -> UnliftedArray a
unsafeCreateUnliftedArray !n f =
  runUnliftedArray
    (unsafeNewUnliftedArray n >>= \marr -> f marr >> pure marr)
-- | Run an 'ST' computation that yields a mutable unlifted array and return
-- that array frozen. The work is done by 'runUnliftedArray#'; this wrapper
-- only boxes the resulting 'ArrayArray#'.
runUnliftedArray
  :: (forall s. ST s (MutableUnliftedArray s a))
  -> UnliftedArray a
{-# INLINE runUnliftedArray #-}
runUnliftedArray action =
  case runUnliftedArray# action of
    frozen# -> UnliftedArray frozen#
-- | Worker for 'runUnliftedArray' that returns the raw 'ArrayArray#'.
--
-- The 'ST' action is unwrapped with 'unST' and run under 'Exts.runRW#'.
-- The mutable array it yields is then frozen with
-- 'Exts.unsafeFreezeArrayArray#', using the state token the action
-- returned, and the final state token is discarded.
runUnliftedArray#
  :: (forall s. ST s (MutableUnliftedArray s a))
  -> ArrayArray#
runUnliftedArray# m = case Exts.runRW# $ \s ->
  case unST m s of { (# s', MutableUnliftedArray mary# #) ->
  Exts.unsafeFreezeArrayArray# mary# s'} of (# _, ary# #) -> ary#
-- | Expose the state-passing function wrapped by an 'ST' action.
unST :: ST s a -> State# s -> (# State# s, a #)
unST action = case action of
  ST.ST run -> run
-- | Produce an immutable copy of a slice of an immutable unlifted array.
--
-- A fresh array of the requested length is allocated, the slice is copied
-- into it, and the result is frozen by 'runUnliftedArray'.
cloneUnliftedArray
  :: UnliftedArray a -- ^ source
  -> Int -- ^ offset into the source
  -> Int -- ^ number of elements to copy
  -> UnliftedArray a
{-# inline cloneUnliftedArray #-}
cloneUnliftedArray src off len = runUnliftedArray $ do
  dst <- unsafeNewUnliftedArray len
  copyUnliftedArray dst 0 src off len
  pure dst
-- | Produce a fresh mutable copy of a slice of a mutable unlifted array.
--
-- A new array of the requested length is allocated and the slice is copied
-- into it; the source is left untouched.
cloneMutableUnliftedArray
  :: PrimMonad m
  => MutableUnliftedArray (PrimState m) a -- ^ source
  -> Int -- ^ offset into the source
  -> Int -- ^ number of elements to copy
  -> m (MutableUnliftedArray (PrimState m) a)
{-# inline cloneMutableUnliftedArray #-}
cloneMutableUnliftedArray src off len =
  unsafeNewUnliftedArray len >>= \dst ->
    copyMutableUnliftedArray dst 0 src off len >> pure dst
-- | An unlifted array with zero slots.
--
-- The @NOINLINE@ pragma keeps this as a single top-level value rather than
-- re-running the allocation at each use site.
emptyUnliftedArray :: UnliftedArray a
{-# NOINLINE emptyUnliftedArray #-}
emptyUnliftedArray = runUnliftedArray $ unsafeNewUnliftedArray 0
-- | Concatenate two unlifted arrays into a freshly allocated array whose
-- size is the sum of the two input sizes: the first array's elements come
-- first, followed by the second's.
concatUnliftedArray :: UnliftedArray a -> UnliftedArray a -> UnliftedArray a
{-# INLINE concatUnliftedArray #-}
concatUnliftedArray x y =
  unsafeCreateUnliftedArray (lenX + lenY) $ \dst ->
    copyUnliftedArray dst 0 x 0 lenX
      >> copyUnliftedArray dst lenX y 0 lenY
  where
    lenX = sizeofUnliftedArray x
    lenY = sizeofUnliftedArray y
-- | Lazy right-associated fold over the elements of an unlifted array,
-- visiting indices in increasing order. The recursive result is passed to
-- the combining function unevaluated.
foldrUnliftedArray :: forall a b. PrimUnlifted a => (a -> b -> b) -> b -> UnliftedArray a -> b
{-# INLINE foldrUnliftedArray #-}
foldrUnliftedArray f z arr = step 0
  where
    !len = sizeofUnliftedArray arr
    step !ix =
      if ix < len
        then f (indexUnliftedArray arr ix) (step (ix + 1))
        else z
-- | Strict right-associated fold over the elements of an unlifted array.
-- The array is walked from the last index down to zero and the accumulator
-- is forced at every step.
foldrUnliftedArray' :: forall a b. PrimUnlifted a => (a -> b -> b) -> b -> UnliftedArray a -> b
{-# INLINE foldrUnliftedArray' #-}
foldrUnliftedArray' f z0 arr = loop lastIx z0
  where
    lastIx = sizeofUnliftedArray arr - 1
    loop !ix !acc =
      if ix >= 0
        then loop (ix - 1) (f (indexUnliftedArray arr ix) acc)
        else acc
-- | Lazy left-associated fold over the elements of an unlifted array. The
-- recursion starts at the last index and the result for the preceding
-- prefix is passed to the combining function unevaluated.
foldlUnliftedArray :: forall a b. PrimUnlifted a => (b -> a -> b) -> b -> UnliftedArray a -> b
{-# INLINE foldlUnliftedArray #-}
foldlUnliftedArray f z arr = step lastIx
  where
    lastIx = sizeofUnliftedArray arr - 1
    step !ix =
      if ix >= 0
        then f (step (ix - 1)) (indexUnliftedArray arr ix)
        else z
-- | Strict left-associated fold over the elements of an unlifted array.
-- Indices are visited in increasing order and the accumulator is forced at
-- every step.
foldlUnliftedArray' :: forall a b. PrimUnlifted a => (b -> a -> b) -> b -> UnliftedArray a -> b
{-# INLINE foldlUnliftedArray' #-}
foldlUnliftedArray' f z0 arr = loop 0 z0
  where
    !len = sizeofUnliftedArray arr
    loop !ix !acc =
      if ix < len
        then loop (ix + 1) (f acc (indexUnliftedArray arr ix))
        else acc
-- | Strict left-associated monadic fold over the elements of an unlifted
-- array. Indices are visited in increasing order; the accumulator produced
-- by each step is forced before the next step runs.
foldlUnliftedArrayM' :: (PrimUnlifted a, Monad m)
  => (b -> a -> m b) -> b -> UnliftedArray a -> m b
{-# INLINE foldlUnliftedArrayM' #-}
foldlUnliftedArrayM' f z0 arr = loop 0 z0
  where
    !len = sizeofUnliftedArray arr
    loop !ix !acc =
      if ix < len
        then do
          acc' <- f acc (indexUnliftedArray arr ix)
          loop (ix + 1) acc'
        else pure acc
-- | Run an effectful action on every element of an unlifted array, in
-- increasing index order, discarding the results.
traverseUnliftedArray_ :: (PrimUnlifted a, Applicative m)
  => (a -> m b) -> UnliftedArray a -> m ()
{-# INLINE traverseUnliftedArray_ #-}
traverseUnliftedArray_ f arr = visit 0
  where
    !len = sizeofUnliftedArray arr
    visit !ix =
      if ix < len
        then f (indexUnliftedArray arr ix) *> visit (ix + 1)
        else pure ()
-- | Run an effectful action on every element of an unlifted array together
-- with its index, in increasing index order, discarding the results.
itraverseUnliftedArray_ :: (PrimUnlifted a, Applicative m)
  => (Int -> a -> m b) -> UnliftedArray a -> m ()
{-# INLINE itraverseUnliftedArray_ #-}
itraverseUnliftedArray_ f arr = visit 0
  where
    !len = sizeofUnliftedArray arr
    visit !ix =
      if ix < len
        then f ix (indexUnliftedArray arr ix) *> visit (ix + 1)
        else pure ()
-- | Apply a function to every element of an unlifted array, producing a new
-- array of the same size. Slots are filled in increasing index order.
mapUnliftedArray :: (PrimUnlifted a, PrimUnlifted b)
  => (a -> b)
  -> UnliftedArray a
  -> UnliftedArray b
{-# INLINE mapUnliftedArray #-}
mapUnliftedArray f arr = unsafeCreateUnliftedArray len $ \dst ->
  let fill !ix
        | ix < len =
            writeUnliftedArray dst ix (f (indexUnliftedArray arr ix))
              >> fill (ix + 1)
        | otherwise = pure ()
   in fill 0
  where
    !len = sizeofUnliftedArray arr
-- | Convert an unlifted array to a list of its elements in index order.
--
-- The list is produced through 'Exts.build' over 'foldrUnliftedArray'.
unliftedArrayToList :: PrimUnlifted a => UnliftedArray a -> [a]
{-# INLINE unliftedArrayToList #-}
unliftedArrayToList xs =
  Exts.build (\cons nil -> foldrUnliftedArray cons nil xs)
-- | Build an unlifted array from a list. The list's length is computed
-- first and then handed to 'unliftedArrayFromListN', so the list is
-- traversed twice.
unliftedArrayFromList :: PrimUnlifted a => [a] -> UnliftedArray a
unliftedArrayFromList xs = unliftedArrayFromListN count xs
  where
    count = L.length xs
-- | Build an unlifted array of exactly the given size from a list.
--
-- Elements are written in list order starting at index zero. If the list
-- has fewer or more elements than the stated size, the result is an
-- 'error' raised through 'die'.
unliftedArrayFromListN :: forall a. PrimUnlifted a => Int -> [a] -> UnliftedArray a
unliftedArrayFromListN len vs = unsafeCreateUnliftedArray len fill
  where
    fill :: forall s. MutableUnliftedArray s a -> ST s ()
    fill marr = go vs 0
      where
        go :: [a] -> Int -> ST s ()
        -- List exhausted: every slot must have been written by now.
        go [] !ix
          | ix == len = pure ()
          | otherwise =
              die "unliftedArrayFromListN" "list length less than specified size"
        -- Another element: it must still fit inside the array.
        go (x : rest) !ix
          | ix < len = writeUnliftedArray marr ix x >> go rest (ix + 1)
          | otherwise =
              die "unliftedArrayFromListN" "list length greater than specified size"
-- | List conversions for unlifted arrays, delegating to
-- 'unliftedArrayToList', 'unliftedArrayFromList' and
-- 'unliftedArrayFromListN'.
instance PrimUnlifted a => Exts.IsList (UnliftedArray a) where
  type Item (UnliftedArray a) = a
  toList arr = unliftedArrayToList arr
  fromList xs = unliftedArrayFromList xs
  fromListN n xs = unliftedArrayFromListN n xs
-- | Combining two unlifted arrays concatenates them via
-- 'concatUnliftedArray'.
instance PrimUnlifted a => Semigroup (UnliftedArray a) where
  lhs <> rhs = concatUnliftedArray lhs rhs
-- | The identity element is the shared 'emptyUnliftedArray'.
instance PrimUnlifted a => Monoid (UnliftedArray a) where
  mempty = emptyUnliftedArray
-- | Renders an array as @fromListN \<size\> \<element list\>@, wrapped in
-- parentheses when the surrounding precedence is above 10.
instance (Show a, PrimUnlifted a) => Show (UnliftedArray a) where
  showsPrec prec arr = showParen (prec > 10) rendered
    where
      rendered =
        showString "fromListN "
          . shows (sizeofUnliftedArray arr)
          . showString " "
          . shows (unliftedArrayToList arr)
-- | Two mutable arrays are equal when 'sameMutableUnliftedArray' says they
-- are the same array; element contents are not compared.
instance Eq (MutableUnliftedArray s a) where
  lhs == rhs = sameMutableUnliftedArray lhs rhs
-- | Structural equality: two arrays are equal when they have the same size
-- and equal elements at every index. Elements are compared from the last
-- index downwards, stopping at the first mismatch.
instance (Eq a, PrimUnlifted a) => Eq (UnliftedArray a) where
  lhs == rhs
    | len /= sizeofUnliftedArray rhs = False
    | otherwise = elementsMatch (len - 1)
    where
      len = sizeofUnliftedArray lhs
      elementsMatch ix
        | ix < 0 = True
        | indexUnliftedArray lhs ix == indexUnliftedArray rhs ix =
            elementsMatch (ix - 1)
        | otherwise = False
-- | Abort with an 'error' whose message is prefixed by this module's name
-- and the name of the function reporting the problem.
--
-- The prefix matches the module name declared in the header,
-- @Data.Primitive.Unlifted.Array@, so the message points at a module that
-- actually exists.
die
  :: String -- ^ name of the function reporting the problem
  -> String -- ^ description of the problem
  -> a
die fun problem =
  error (concat ["Data.Primitive.Unlifted.Array.", fun, ": ", problem])