{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE UnboxedTuples #-}
-- |
-- Module      :  System.Random.Array
-- Copyright   :  (c) Alexey Kuleshevich 2024
-- License     :  BSD-style (see the file LICENSE in the 'random' repository)
-- Maintainer  :  libraries@haskell.org
--
module System.Random.Array
  ( -- * Helper array functionality
    ioToST
  , wordSizeInBits
    -- ** MutableByteArray
  , newMutableByteArray
  , newPinnedMutableByteArray
  , freezeMutableByteArray
  , writeWord8
  , writeWord64LE
  , writeByteSliceWord64LE
  , indexWord8
  , indexWord64LE
  , indexByteSliceWord64LE
  , sizeOfByteArray
  , shortByteStringToByteArray
  , byteArrayToShortByteString
  , getSizeOfMutableByteArray
  , shortByteStringToByteString
  -- ** MutableArray
  , Array (..)
  , MutableArray (..)
  , newMutableArray
  , freezeMutableArray
  , writeArray
  , shuffleListM
  , shuffleListST
  ) where

import Control.Monad.Trans (lift, MonadTrans)
import Control.Monad (when)
import Control.Monad.ST
import Data.Array.Byte (ByteArray(..), MutableByteArray(..))
import Data.Bits
import Data.ByteString.Short.Internal (ShortByteString(SBS))
import qualified Data.ByteString.Short.Internal as SBS (fromShort)
import Data.Word
import GHC.Exts
import GHC.IO (IO(..))
import GHC.ST (ST(..))
import GHC.Word
#if __GLASGOW_HASKELL__ >= 802
import Data.ByteString.Internal (ByteString(PS))
import GHC.ForeignPtr
#else
import Data.ByteString (ByteString)
#endif

-- Needed for WORDS_BIGENDIAN
#include "MachDeps.h"

-- | Width of the machine 'Word' type, measured in bits.
wordSizeInBits :: Int
wordSizeInBits = finiteBitSize (maxBound :: Word)

----------------
-- Byte Array --
----------------

-- Architecture independent helpers:

-- | Number of bytes stored in an immutable 'ByteArray'.
sizeOfByteArray :: ByteArray -> Int
sizeOfByteArray ba =
  case ba of
    ByteArray ba# -> I# (sizeofByteArray# ba#)

-- | Lift a primitive state-token transformer that produces no result into 'ST'.
st_ :: (State# s -> State# s) -> ST s ()
st_ m# = ST (\s# -> case m# s# of s'# -> (# s'#, () #))
{-# INLINE st_ #-}

-- | Reinterpret an 'IO' action as an 'ST' computation over 'RealWorld'.
-- Both newtypes wrap the same state-passing function, so the wrapped function is
-- reused unchanged.
ioToST :: IO a -> ST RealWorld a
ioToST io = case io of IO m# -> ST m#
{-# INLINE ioToST #-}

-- | Allocate a fresh 'MutableByteArray' of the requested size in bytes.
-- The array is not guaranteed to be pinned, and this function does not set its
-- initial contents.
newMutableByteArray :: Int -> ST s (MutableByteArray s)
newMutableByteArray (I# n#) = ST allocate
  where
    allocate s0 =
      case newByteArray# n# s0 of
        (# s1, mba# #) -> (# s1, MutableByteArray mba# #)
{-# INLINE newMutableByteArray #-}

-- | Allocate a 'MutableByteArray' of the requested size in bytes in pinned memory,
-- which the garbage collector will not move. This function does not set the
-- initial contents.
newPinnedMutableByteArray :: Int -> ST s (MutableByteArray s)
newPinnedMutableByteArray (I# size#) = ST $ \state# ->
  case newPinnedByteArray# size# state# of
    (# state'#, mba# #) -> (# state'#, MutableByteArray mba# #)
{-# INLINE newPinnedMutableByteArray #-}

-- | Freeze a 'MutableByteArray' into an immutable 'ByteArray' without copying.
--
-- This is an /unsafe/ freeze (it uses 'unsafeFreezeByteArray#'). The returned
-- 'ByteArray' shares memory with the mutable array, so the mutable array must not
-- be written to after this call.
freezeMutableByteArray :: MutableByteArray s -> ST s ByteArray
freezeMutableByteArray (MutableByteArray mba#) =
  ST $ \s# ->
    case unsafeFreezeByteArray# mba# s# of
      (# s'#, ba# #) -> (# s'#, ByteArray ba# #)
-- Inlined like every other primop wrapper in this module, so callers do not pay
-- for an extra call or for boxing the result.
{-# INLINE freezeMutableByteArray #-}

-- | Store one byte at the given byte offset. No bounds checking is performed.
writeWord8 :: MutableByteArray s -> Int -> Word8 -> ST s ()
writeWord8 (MutableByteArray mba#) (I# off#) (W8# byte#) =
  st_ (writeWord8Array# mba# off# byte#)
{-# INLINE writeWord8 #-}

-- | Write the bytes of a 'Word64' into the byte range @[fromByteIx, toByteIx)@,
-- least-significant byte first (little-endian order), one byte at a time.
-- If the range is wider than 8 bytes, the extra positions receive zero bytes.
-- An empty or inverted range writes nothing.
writeByteSliceWord64LE :: MutableByteArray s -> Int -> Int -> Word64 -> ST s ()
writeByteSliceWord64LE mba fromByteIx toByteIx = loop fromByteIx
  where
    loop !ix !acc
      | ix >= toByteIx = pure ()
      | otherwise = do
          writeWord8 mba ix (fromIntegral acc)
          -- Move the next most significant byte into the low position
          loop (ix + 1) (acc `shiftR` 8)
{-# INLINE writeByteSliceWord64LE #-}

-- | Read a single byte from an immutable 'ByteArray'. No bounds checking is performed.
indexWord8 ::
     ByteArray
  -> Int -- ^ Offset into immutable byte array in number of bytes
  -> Word8
indexWord8 (ByteArray ba#) (I# off#) =
  case indexWord8Array# ba# off# of
    byte# -> W8# byte#
{-# INLINE indexWord8 #-}

-- | Read 8 bytes starting at an arbitrary byte offset and interpret them as a
-- little-endian 'Word64'. No bounds checking is performed.
indexWord64LE ::
     ByteArray
  -> Int -- ^ Offset into immutable byte array in number of bytes
  -> Word64
#if defined WORDS_BIGENDIAN || !(__GLASGOW_HASKELL__ >= 806)
-- Big-endian hosts, and GHC versions without byte-offset read primops, assemble
-- the value one byte at a time.
indexWord64LE ba i = indexByteSliceWord64LE ba i (i + 8)
#else
indexWord64LE (ByteArray ba#) (I# i#)
  | wordSizeInBits == 64 = W64# (indexWord8ArrayAsWord64# ba# i#)
  | otherwise = (upper `shiftL` 32) .|. lower
  where
    -- 32-bit hosts read two 32-bit halves: the low half sits at the lower address
    lower, upper :: Word64
    lower = fromIntegral (W32# (indexWord8ArrayAsWord32# ba# i#))
    upper = fromIntegral (W32# (indexWord8ArrayAsWord32# ba# (i# +# 4#)))
#endif
{-# INLINE indexWord64LE #-}

-- | Read the bytes in the range @[fromByteIx, toByteIx)@ as a little-endian
-- 'Word64'. The byte at @fromByteIx@ becomes the least-significant byte.
-- Ranges shorter than 8 bytes are zero-extended, and an empty range yields 0.
--
-- NOTE(review): ranges longer than 8 bytes do not yield the first 8 bytes, because
-- earlier bytes are shifted out of the accumulator. Callers appear to pass at most
-- 8 bytes; confirm.
indexByteSliceWord64LE ::
     ByteArray
  -> Int -- ^ Starting offset in number of bytes
  -> Int -- ^ Ending offset in number of bytes
  -> Word64
indexByteSliceWord64LE ba fromByteIx toByteIx = loop fromByteIx 0
  where
    -- Number of bytes in the final, possibly partial, group of 8
    leftover :: Int
    leftover = (toByteIx - fromByteIx) `rem` 8
    -- Shift that moves the first byte read into the top byte position
    padBits :: Int
    padBits
      | leftover == 0 = 0
      | otherwise = 8 * (8 - leftover)
    -- Bytes are accumulated big-endian; the final byte swap turns that into little-endian
    loop :: Int -> Word64 -> Word64
    loop ix !acc
      | ix >= toByteIx = byteSwap64 (acc `shiftL` padBits)
      | otherwise =
          loop (ix + 1) ((acc `shiftL` 8) .|. fromIntegral (indexWord8 ba ix))
{-# INLINE indexByteSliceWord64LE #-}

-- | Write a 'Word64' as 8 little-endian bytes starting at the given byte offset.
-- No bounds checking is performed.
--
-- On big endian machines we need to write one byte at a time for consistency with
-- little endian machines. GHC versions prior to 8.6 lack primops that can write at a
-- byte offset (e.g. writeWord8ArrayAsWord64# and writeWord8ArrayAsWord32#), so they
-- must also fall back to writing one byte at a time. That fallback is about 3 times
-- slower, which is not the end of the world.
writeWord64LE ::
     MutableByteArray s
  -> Int -- ^ Offset into mutable byte array in number of bytes
  -> Word64 -- ^ 8 bytes that will be written into the supplied array
  -> ST s ()
#if defined WORDS_BIGENDIAN || !(__GLASGOW_HASKELL__ >= 806)
writeWord64LE mba i w64 =
  writeByteSliceWord64LE mba i (i + 8) w64
#else
writeWord64LE (MutableByteArray mba#) (I# i#) w64@(W64# w64#)
  | wordSizeInBits == 64 = st_ (writeWord8ArrayAsWord64# mba# i# w64#)
  | otherwise =
      -- 32-bit host: the low half goes to the lower address, the high half 4 bytes later
      case (fromIntegral w64, fromIntegral (w64 `shiftR` 32)) of
        (W32# lo#, W32# hi#) ->
          st_ (writeWord8ArrayAsWord32# mba# i# lo#)
            >> st_ (writeWord8ArrayAsWord32# mba# (i# +# 4#) hi#)
#endif
{-# INLINE writeWord64LE #-}

-- | Query the current size, in bytes, of a 'MutableByteArray'.
getSizeOfMutableByteArray :: MutableByteArray s -> ST s Int
getSizeOfMutableByteArray (MutableByteArray mba#) =
#if __GLASGOW_HASKELL__ >=802
  ST querySize
  where
    querySize s0 =
      case getSizeofMutableByteArray# mba# s0 of
        (# s1, n# #) -> (# s1, I# n# #)
#else
  pure $! I# (sizeofMutableByteArray# mba#)
#endif
{-# INLINE getSizeOfMutableByteArray #-}

-- | View a 'ShortByteString' as a 'ByteArray'. Both wrap the same 'ByteArray#',
-- so no bytes are copied.
shortByteStringToByteArray :: ShortByteString -> ByteArray
shortByteStringToByteArray sbs = case sbs of SBS ba# -> ByteArray ba#
{-# INLINE shortByteStringToByteArray #-}

-- | View a 'ByteArray' as a 'ShortByteString'. Both wrap the same 'ByteArray#',
-- so no bytes are copied.
byteArrayToShortByteString :: ByteArray -> ShortByteString
byteArrayToShortByteString arr = case arr of ByteArray ba# -> SBS ba#
{-# INLINE byteArrayToShortByteString #-}

-- | Convert a 'ShortByteString' to a 'ByteString'. If the underlying memory is
-- pinned, the bytes are shared by casting; otherwise they are copied into a new
-- pinned 'ByteString'.
shortByteStringToByteString :: ShortByteString -> ByteString
shortByteStringToByteString ba =
#if __GLASGOW_HASKELL__ < 802
  SBS.fromShort ba
#else
  case ba of
    SBS ba#
      | isTrue# (isByteArrayPinned# ba#) -> pinnedByteArrayToByteString ba#
      | otherwise -> SBS.fromShort ba
{-# INLINE shortByteStringToByteString #-}

-- | Wrap a pinned 'ByteArray#' as a 'ByteString' covering the whole array, without
-- copying. The array must be pinned, because the 'ByteString' refers to its raw
-- payload address.
pinnedByteArrayToByteString :: ByteArray# -> ByteString
pinnedByteArrayToByteString ba# =
  PS (pinnedByteArrayToForeignPtr ba#) 0 (I# (sizeofByteArray# ba#))
{-# INLINE pinnedByteArrayToByteString #-}

-- | Build a 'ForeignPtr' that points at the payload of a pinned 'ByteArray#'.
--
-- 'PlainPtr' expects a 'MutableByteArray#', so the immutable array is passed
-- through 'unsafeCoerce#'. NOTE(review): this appears to rely on both primitive
-- types sharing one heap representation. The coercion seems to exist only so the
-- 'ForeignPtr' keeps the array reachable; nothing in this module writes through
-- it. Confirm before changing.
pinnedByteArrayToForeignPtr :: ByteArray# -> ForeignPtr a
pinnedByteArrayToForeignPtr ba# =
  ForeignPtr (byteArrayContents# ba#) (PlainPtr (unsafeCoerce# ba#))
{-# INLINE pinnedByteArrayToForeignPtr #-}
#endif

-----------------
-- Boxed Array --
-----------------

-- | Immutable boxed array: a lifted wrapper around the primitive 'Array#', so it
-- can be stored and passed around like any ordinary Haskell value.
data Array a = Array (Array# a)

-- | Boxed array that can be mutated within the state thread @s@; a lifted wrapper
-- around the primitive 'MutableArray#'.
data MutableArray s a = MutableArray (MutableArray# s a)

-- | Allocate a boxed mutable array with the given number of elements, every slot
-- initialised to the supplied value.
newMutableArray :: Int -> a -> ST s (MutableArray s a)
newMutableArray (I# n#) initial =
  ST
    (\s0 ->
       case newArray# n# initial s0 of
         (# s1, ma# #) -> (# s1, MutableArray ma# #))
{-# INLINE newMutableArray #-}

-- | Freeze a boxed mutable array in place, without copying. The returned 'Array'
-- shares storage with the mutable one (via 'unsafeFreezeArray#'), so the mutable
-- array must not be modified afterwards.
freezeMutableArray :: MutableArray s a -> ST s (Array a)
freezeMutableArray (MutableArray ma#) = ST freeze
  where
    freeze s0 =
      case unsafeFreezeArray# ma# s0 of
        (# s1, a# #) -> (# s1, Array a# #)
{-# INLINE freezeMutableArray #-}

-- | Number of elements in a boxed mutable array.
sizeOfMutableArray :: MutableArray s a -> Int
sizeOfMutableArray marr =
  case marr of
    MutableArray marr# -> I# (sizeofMutableArray# marr#)
{-# INLINE sizeOfMutableArray #-}

-- | Read the element at the given index. No bounds checking is performed.
readArray :: MutableArray s a -> Int -> ST s a
readArray (MutableArray marr#) (I# ix#) = ST (\s -> readArray# marr# ix# s)
{-# INLINE readArray #-}

-- | Store an element at the given index. No bounds checking is performed.
writeArray :: MutableArray s a -> Int -> a -> ST s ()
writeArray (MutableArray marr#) (I# ix#) x = st_ (writeArray# marr# ix# x)
{-# INLINE writeArray #-}

-- | Exchange the elements stored at two indices. Both values are read before
-- either slot is written, so @i == j@ leaves the array unchanged.
swapArray :: MutableArray s a -> Int -> Int -> ST s ()
swapArray ma i j = do
  atI <- readArray ma i
  atJ <- readArray ma j
  writeArray ma i atJ
  writeArray ma j atI
{-# INLINE swapArray #-}

-- | Copy the list elements into the mutable array, starting at index 0 and in list
-- order. No bounds checking is performed: the array must hold at least as many
-- elements as the list, or a segfault will happen.
fillMutableArrayFromList :: MutableArray s a -> [a] -> ST s ()
fillMutableArrayFromList ma xs = sequence_ (zipWith (writeArray ma) [0 ..] xs)
{-# INLINE fillMutableArrayFromList #-}

-- | Read every element of the mutable array into a list, in index order.
readListFromMutableArray :: MutableArray s a -> ST s [a]
readListFromMutableArray ma = collect (sizeOfMutableArray ma - 1) []
  where
    -- Walk from the last index down to 0, consing so the result ends up in index order
    collect ix !acc
      | ix < 0 = pure acc
      | otherwise = do
          x <- readArray ma ix
          collect (ix - 1) (x : acc)
{-# INLINE readListFromMutableArray #-}


-- | Pre-generate the swap targets used by a uniform (Fisher-Yates) shuffle.
--
-- The generator action is called with bounds @1, 2, ..., n - 1@, in that order. The
-- results are returned in reverse order of generation, so the head of the list pairs
-- with position @n - 1@ and the last element pairs with position @1@:
--
-- @
-- [ (0, n - 1)
-- , (0, n - 2)
-- , (0, n - 3)
-- , ...
-- , (0, 3)
-- , (0, 2)
-- , (0, 1)
-- ]
-- @
--
-- For @n <= 1@ the generator is never called and the result is empty.
genSwapIndices
  :: Monad m
  => (Word -> m Word)
  -- ^ Action that generates a Word in the supplied range.
  -> Word
  -- ^ Number of index swaps to generate.
  -> m [Int]
genSwapIndices genWordR n = loop 1 []
  where
    loop !i !acc
      | i < n = do
          w <- genWordR i
          let !ix = fromIntegral w
          loop (i + 1) (ix : acc)
      | otherwise = pure acc
{-# INLINE genSwapIndices #-}


-- | Fisher-Yates shuffle of a list, driven by a monadic bounded-'Word' generator.
--
-- Random number generation cannot in general be interleaved with mutation in the
-- `ST` monad. So every swap index is first generated with `genSwapIndices` and kept
-- in a list, and only then are the swaps performed on a mutable array inside
-- 'runST'. Lists of length 0 or 1 are returned unchanged without calling the generator.
shuffleListM :: Monad m => (Word -> m Word) -> [a] -> m [a]
shuffleListM genWordR ls
  | len <= 1 = pure ls
  | otherwise = do
      swapIxs <- genSwapIndices genWordR (fromIntegral len)
      pure $ runST $ do
        ma <- newMutableArray len $ error "Impossible: shuffleListM"
        fillMutableArrayFromList ma ls
        -- Positions len - 1, len - 2, ..., 1 are each swapped with the matching
        -- pre-generated index; zipWith stops once the len - 1 indices run out.
        sequence_ (zipWith (swapArray ma) [len - 1, len - 2 ..] swapIxs)
        readListFromMutableArray ma
  where
    len = length ls
{-# INLINE shuffleListM #-}

-- | A roughly 2-3x faster variant of `shuffleListM`. It draws each swap index on
-- demand instead of pre-generating a list of them. Because the indices are drawn in
-- the reverse order (largest bound first), the result differs from `shuffleListM`
-- for the same generator.
--
-- Most stateful generator monads cannot work through `MonadTrans`, so this version
-- is only used to implement the pure shuffle.
shuffleListST :: (Monad (t (ST s)), MonadTrans t) => (Word -> t (ST s) Word) -> [a] -> t (ST s) [a]
shuffleListST genWordR ls
  | len <= 1 = pure ls
  | otherwise = do
      ma <- lift $ newMutableArray len $ error "Impossible: shuffleListST"
      lift $ fillMutableArrayFromList ma ls
      -- Walk positions from len - 1 down to 1, swapping each with a freshly drawn index
      let swapDownFrom i
            | i <= 0 = pure ()
            | otherwise = do
                j <- genWordR (fromIntegral i)
                lift $ swapArray ma i (fromIntegral j)
                swapDownFrom (i - 1)
      swapDownFrom (len - 1)
      lift $ readListFromMutableArray ma
  where
    len = length ls
{-# INLINE shuffleListST #-}