{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE OverloadedStrings #-}

module DataFrame.Operations.Permutation where

import qualified Data.List as L
import qualified Data.Text as T
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM

import Control.Exception (throw)
import Control.Monad.ST (runST)
import Data.Vector.Internal.Check (HasCallStack)
import DataFrame.Errors (DataFrameException (..))
import DataFrame.Internal.Column (Columnable, atIndicesStable)
import DataFrame.Internal.DataFrame (DataFrame (..))
import DataFrame.Internal.Expression (Expr (Col))
import DataFrame.Internal.Row (sortedIndexes', toRowVector)
import DataFrame.Operations.Core (columnNames, dimensions)
import System.Random (Random (randomR), RandomGen)

{- | Sort order taken as a parameter by the 'sortBy' function.

Each constructor wraps the expression to sort on; the element type @a@ is
hidden, so sort orders over columns of different types can share one list.
Note that 'getSortColumnName' only accepts plain column references ('Col');
any other expression makes it call 'error'.
-}
data SortOrder where
    -- | Ascending order on the wrapped expression (direction inferred from the
    -- constructor name; the comparison itself happens in 'sortedIndexes'').
    Asc :: (Columnable a) => Expr a -> SortOrder
    -- | Descending order on the wrapped expression (same caveat as 'Asc').
    Desc :: (Columnable a) => Expr a -> SortOrder

-- | Two sort orders compare equal exactly when they use the same constructor
-- (both 'Asc' or both 'Desc'); the wrapped expressions are never inspected.
instance Eq SortOrder where
    (==) :: SortOrder -> SortOrder -> Bool
    Asc _ == Asc _ = True
    Desc _ == Desc _ = True
    _ == _ = False

-- | Extract the column name a sort order refers to.
--
-- Only a bare column reference ('Col') is supported, regardless of direction.
-- Any other expression aborts with 'error'.
getSortColumnName :: SortOrder -> T.Text
getSortColumnName order = case order of
    Asc (Col name) -> name
    Desc (Col name) -> name
    -- Anything that is not a plain column reference cannot be sorted on here.
    _ -> error "Sorting on compound column"

-- | Per-key flag that 'sortBy' hands to 'sortedIndexes'': 'True' for 'Asc',
-- 'False' for 'Desc'.
mustFlipCompare :: SortOrder -> Bool
mustFlipCompare order = case order of
    Asc _ -> True
    Desc _ -> False

{- | Sort the rows of a dataframe by the columns named in a list of
'SortOrder's.

Every sort order must wrap a plain column reference ('Col'); the column names
are extracted with 'getSortColumnName', which calls 'error' for any other
expression. The row order is computed by 'sortedIndexes'' over the selected
columns and then applied to every column of the dataframe.

Complexity is documented as O(k log n); the actual cost is determined by
'sortedIndexes''.

Example (assuming the @"Age"@ column holds 'Int's):

> sortBy [Asc (Col "Age" :: Expr Int)] df

Throws 'ColumnNotFoundException' when at least one requested column does not
exist in the dataframe. The exception lists exactly the requested names that
are absent.
-}
sortBy ::
    [SortOrder] ->
    DataFrame ->
    DataFrame
sortBy sortOrds df
    | not (null missing) =
        throw $
            ColumnNotFoundException
                (T.pack (show missing))
                "sortBy"
                available
    | otherwise =
        let
            indexes = sortedIndexes' mustFlips (toRowVector names df)
         in
            df{columns = V.map (atIndicesStable indexes) (columns df)}
  where
    names :: [T.Text]
    names = map getSortColumnName sortOrds

    mustFlips :: [Bool]
    mustFlips = map mustFlipCompare sortOrds

    available :: [T.Text]
    available = columnNames df

    -- Requested names that are not columns of the dataframe. A membership
    -- filter is used instead of list difference: list difference removes only
    -- one occurrence per element, so a sort spec that repeats an existing
    -- column would wrongly report that column as missing.
    missing :: [T.Text]
    missing = L.filter (`notElem` available) names

-- | Reorder the entries of every column of a dataframe using one random index
-- vector obtained from 'shuffledIndices'.
--
-- The same index vector is applied to all columns, so values that shared a
-- position before the shuffle still share a position afterwards. The result is
-- fully determined by the supplied generator.
shuffle ::
    (RandomGen g) =>
    g ->
    DataFrame ->
    DataFrame
shuffle pureGen df = df{columns = V.map (atIndicesStable order) (columns df)}
  where
    -- First component of 'dimensions' (presumably the row count).
    extent = fst (dimensions df)
    order = shuffledIndices pureGen extent

{- | Build a random permutation of the indices @0 .. k - 1@.

The permutation is produced by a forward Fisher–Yates pass over a mutable
vector initialised to @[0 .. k - 1]@: each position is swapped with a partner
drawn by 'randomR' from the not-yet-fixed suffix, the position itself included.
Every permutation is therefore reachable, including the identity.

* @k < 0@ calls 'error'.
* @k == 0@ yields the empty vector.
* @k == 1@ yields the single index @0@ without consulting the generator.

The function is pure: the same generator and @k@ always give the same result.
-}
shuffledIndices :: (HasCallStack, RandomGen g) => g -> Int -> VU.Vector Int
shuffledIndices pureGen k
    | k < 0 = error $ "Vector index may not be a negative number: " <> show k
    | k == 0 = VU.empty
    | otherwise = runST $ do
        vm <- VUM.generate k id
        go vm (k - 1) pureGen
        -- vm is not touched after this point, so freezing in place is safe.
        VU.unsafeFreeze vm
  where
    -- @go v maxInd gen@ shuffles @v@, whose last valid index is @maxInd@.
    -- A slice of length <= 1 has nothing left to swap.
    go v maxInd gen
        | maxInd <= 0 = pure ()
        | otherwise =
            let
                -- The partner may be index 0 itself, which leaves the head
                -- element in place.
                (j, nextGen) = randomR (0, maxInd) gen
             in
                VUM.swap v 0 j *> go (VUM.tail v) (maxInd - 1) nextGen