module PrimitiveExtras.Folds
  ( indexCounts,
    unliftedArray,
    primMultiArray,
  )
where

import Control.Foldl
import PrimitiveExtras.Prelude hiding (fold, foldM)
import PrimitiveExtras.Types
import qualified PrimitiveExtras.UnliftedArray as UA

-- |
-- Lift an 'IO'-based step\/init\/extract triple into a pure 'Fold'.
--
-- The state is created on the first step of each run, or by the extractor
-- when the input is empty, instead of being stored as the fold's initial
-- value. Storing the result of @unsafeDupablePerformIO initInIO@ directly in
-- the 'Fold' would make it a single shared thunk. Every use of the same
-- 'Fold' value would then mutate one state, for example folding it over two
-- inputs or combining it with itself via '<*>'. That corrupts results handed
-- out earlier.
--
-- This is still unsafe: the extractor may expose the state without copying
-- it, as the folds in this module do via unsafe freezes. Combinators that
-- extract and then keep stepping the same state (such as 'scan') will observe
-- later mutations in earlier outputs.
unsafeIO :: (state -> input -> IO state) -> IO state -> (state -> IO output) -> Fold input output
unsafeIO stepInIO initInIO extractInIO =
  Fold step Nothing extract
  where
    -- This run's state, allocated on demand if no step has happened yet.
    current = maybe initInIO pure
    step !acc input =
      Just $! unsafeDupablePerformIO (current acc >>= \state -> stepInIO state input)
    extract acc =
      unsafePerformIO (current acc >>= extractInIO)

-- |
-- Given a size of the array,
-- construct a fold which produces an array of index counts.
-- The element at index @i@ of the result is the number of times @i@
-- occurred in the input.
--
-- Every input index must lie in @[0, size)@; indices are not bounds-checked.
-- Counts are incremented with 'succ'. For the standard bounded integral types,
-- overflowing @count@ therefore raises an error rather than wrapping.
indexCounts ::
  (Integral count, Prim count) =>
  -- | Array size
  Int ->
  Fold Int (PrimArray count)
indexCounts size = unsafeIO step init extract
  where
    -- Mirror 'replicatePrimArray', which yields an empty array for a
    -- negative size.
    length' = max 0 size
    -- Allocate and zero a fresh buffer each time this action runs. Thawing a
    -- pure @replicatePrimArray size 0@ here would reuse one memoized buffer
    -- across runs, so separate runs would mutate the same array.
    init = do
      counts <- newPrimArray length'
      forMFromZero_ length' $ \i -> writePrimArray counts i 0
      return counts
    step counts i = do
      count <- readPrimArray counts i
      writePrimArray counts i (succ count)
      return counts
    -- Freezes in place: the buffer must not be stepped again afterwards.
    extract = unsafeFreezePrimArray

-- |
-- Given the size of the array, construct a fold over @(index, element)@
-- pairs. The fold writes each element at its index and produces the
-- resulting array.
--
-- This function is partial:
--
-- * every index must lie within @[0, size)@;
-- * every position should be written before the result is read, because the
--   array is allocated with 'unsafeNewUnliftedArray' and reading an unwritten
--   position is unsafe.
--
-- If an index occurs more than once, the element written last is kept.
unliftedArray ::
  (PrimUnlifted element) =>
  -- | Size of the array
  Int ->
  Fold (Int, element) (UnliftedArray element)
unliftedArray size =
  unsafeIO writeAt (unsafeNewUnliftedArray size) unsafeFreezeUnliftedArray
  where
    -- Store the element in place and carry the same mutable array forward.
    writeAt array (index, element) = do
      writeUnliftedArray array index element
      return array

-- |
-- Having a previously computed array of inner dimension sizes,
-- e.g., using the 'indexCounts' fold,
-- construct a fold over indexed elements into a multi-array of elements.
--
-- Thus it allows constructing the multi-array in two passes over the indexed
-- elements: one to count the elements per outer index, one to place them.
--
-- Preconditions, none of which are checked:
--
-- * every outer index must lie in @[0, sizeofPrimArray sizeArray)@;
-- * no outer index may occur more often than its declared inner size.
--
-- Inner positions that receive no element are left uninitialised.
primMultiArray :: forall size element. (Integral size, Prim size, Prim element) => PrimArray size -> Fold (Int, element) (PrimMultiArray element)
primMultiArray sizeArray =
  unsafeIO step init extract
  where
    outerLength = sizeofPrimArray sizeArray
    init =
      Product2 <$> initIndexArray <*> initMultiArray
      where
        -- For each outer index, the next inner position to write. It is
        -- allocated and zeroed every time this action runs: thawing a pure
        -- zero-filled array would reuse one memoized buffer across runs.
        initIndexArray :: IO (MutablePrimArray RealWorld size)
        initIndexArray = do
          cursors <- newPrimArray outerLength
          forMFromZero_ outerLength $ \index -> writePrimArray cursors index 0
          return cursors
        initMultiArray :: IO (UnliftedArray (MutablePrimArray RealWorld element))
        initMultiArray =
          UA.generate outerLength $ \index ->
            newPrimArray (fromIntegral (indexPrimArray sizeArray index))
    step (Product2 indexArray multiArray) (outerIndex, element) = do
      let innerArray = indexUnliftedArray multiArray outerIndex
      innerIndex <- readPrimArray indexArray outerIndex
      writePrimArray indexArray outerIndex (succ innerIndex)
      writePrimArray innerArray (fromIntegral innerIndex) element
      return (Product2 indexArray multiArray)
    -- Freezes the inner arrays in place and collects them into the result;
    -- the state must not be stepped again afterwards.
    extract (Product2 _ multiArray) = do
      copied <- unsafeNewUnliftedArray outerLength
      forMFromZero_ outerLength $ \outerIndex -> do
        let mutableInnerArray = indexUnliftedArray multiArray outerIndex
        frozenInnerArray <- unsafeFreezePrimArray mutableInnerArray
        writeUnliftedArray copied outerIndex frozenInnerArray
      result <- unsafeFreezeUnliftedArray copied
      return $ PrimMultiArray result