{-# LANGUAGE OverloadedStrings #-}

module DataFrame.Operations.Sorting where

import qualified Data.List as L
import qualified Data.Text as T
import qualified Data.Vector as V

import Control.Exception (throw)
import DataFrame.Errors (DataFrameException (..))
import DataFrame.Internal.Column
import DataFrame.Internal.DataFrame (DataFrame (..))
import DataFrame.Internal.Row
import DataFrame.Operations.Core

-- | Sort order taken as a parameter by the 'sortBy' function.
data SortOrder
    = -- | Sort in ascending order.
      Ascending
    | -- | Sort in descending order.
      Descending
    deriving (Eq)

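{- | Sorts the dataframe's rows by the values in the given columns.

Every column named in the list must exist in the dataframe; otherwise a
'ColumnNotFoundException' listing the missing names is thrown.

Cost is presumably about O(k · n log n) for @n@ rows and @k@ sort columns;
the exact figure depends on the underlying index sort.

> sortBy Ascending ["Age"] df
-}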
sortBy ::
    SortOrder ->
    [T.Text] ->
    DataFrame ->
    DataFrame
sortBy order names df
    -- Every sort key exists: reorder each column with the same index
    -- permutation so that rows stay aligned across columns.
    | null missing =
        df{columns = V.map (atIndicesStable permutation) (columns df)}
    -- At least one sort key is unknown: report it instead of sorting.
    | otherwise =
        throw $
            ColumnNotFoundException
                (T.pack (show (names L.\\ available)))
                "sortBy"
                available
  where
    -- Names of the columns actually present in the dataframe.
    available = columnNames df
    -- Requested sort keys that are not columns of the dataframe.
    missing = filter (`notElem` available) names
    -- Index permutation from sortedIndexes' (True requests ascending);
    -- bound once here so every column reuses the same value.
    permutation = sortedIndexes' (order == Ascending) (toRowVector names df)