{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module DataFrame.DecisionTree where

import qualified DataFrame.Functions as F
import DataFrame.Internal.Column
import DataFrame.Internal.DataFrame (DataFrame (..), unsafeGetColumn)
import DataFrame.Internal.Expression (Expr (..), eSize)
import DataFrame.Internal.Interpreter (interpret)
import DataFrame.Internal.Statistics (percentile', percentileOrd')
import DataFrame.Internal.Types
import DataFrame.Operations.Core (columnNames, nRows)
import DataFrame.Operations.Statistics (percentile)
import DataFrame.Operations.Subset (exclude, filterWhere)

import Control.Exception (throw)
import Control.Monad (guard)
import Data.Containers.ListUtils (nubOrd)
import Data.Function (on)
import Data.List (foldl', maximumBy, sortBy)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Text as T
import Data.Type.Equality
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as VU
import Type.Reflection (typeRep)

import DataFrame.Functions ((.<), (.<=), (.==), (.>), (.>=))

-- | Settings that control how 'fitDecisionTree' grows a tree.
data TreeConfig = TreeConfig
    { maxTreeDepth :: Int
    -- ^ Initial depth budget handed to 'buildTree'; a node whose remaining
    -- budget is @<= 0@ becomes a leaf.
    , minSamplesSplit :: Int
    -- ^ Row-count threshold that 'buildTree' compares a node's size against
    -- before attempting to split it.
    , minLeafSize :: Int
    -- ^ NOTE(review): not read by 'fitDecisionTree' or 'buildTree';
    -- presumably consulted during split selection — confirm.
    , synthConfig :: SynthConfig
    -- ^ Settings used to synthesize the candidate split conditions.
    }
    deriving (Eq, Show)

-- | Settings for synthesizing candidate split conditions.
data SynthConfig = SynthConfig
    { maxExprDepth :: Int
    -- ^ Bound that 'numericExprsWithTerms' hands to 'numericExprs' when
    -- building arithmetic features.
    , boolExpansion :: Int
    -- ^ NOTE(review): not read in this part of the module; presumably a
    -- depth bound for 'boolExprs' — confirm.
    , percentiles :: [Int]
    -- ^ Percentile ranks at which numeric thresholds are taken.
    , complexityPenalty :: Double
    -- ^ Complexity penalty applied to split gains in 'findBestSplit'.
    , enableStringOps :: Bool
    -- ^ NOTE(review): not read in this part of the module — confirm its use.
    , enableCrossCols :: Bool
    -- ^ NOTE(review): presumably toggles column-vs-column comparison
    -- conditions — confirm where it is consulted.
    , enableArithOps :: Bool
    -- ^ When 'False', 'numericExprs' synthesizes no arithmetic combinations.
    }
    deriving (Eq, Show)

-- | Default synthesis settings:
--
-- * thresholds at every 10th percentile rank (0, 10, ..., 100);
-- * 'maxExprDepth' and 'boolExpansion' both 2;
-- * a complexity penalty of 0.05;
-- * arithmetic, cross-column and string operation flags all enabled.
defaultSynthConfig :: SynthConfig
defaultSynthConfig =
    SynthConfig
        { percentiles = map (* 10) [0 .. 10]
        , maxExprDepth = 2
        , boolExpansion = 2
        , complexityPenalty = 0.05
        , enableArithOps = True
        , enableCrossCols = True
        , enableStringOps = True
        }

-- | Default tree settings: depth budget 4, split threshold 5, minimum leaf
-- size 1, and 'defaultSynthConfig' for condition synthesis.
defaultTreeConfig :: TreeConfig
defaultTreeConfig =
    TreeConfig
        { synthConfig = defaultSynthConfig
        , maxTreeDepth = 4
        , minLeafSize = 1
        , minSamplesSplit = 5
        }

-- | Fit a decision tree predicting the column referenced by the target
-- expression; the fitted tree is returned as an expression of the same type.
--
-- Only a bare column reference ('Col') is accepted as the target; any other
-- expression is rejected with 'error'.
--
-- Candidate split conditions come from 'numericConditions' followed by
-- 'generateConditionsOld'. Both run over the frame with the target column
-- excluded. The tree is grown by 'buildTree' starting from the depth budget
-- 'maxTreeDepth'.
fitDecisionTree ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    Expr a ->
    DataFrame ->
    Expr a
fitDecisionTree cfg (Col target) df =
    buildTree @a cfg (maxTreeDepth cfg) target candidates df
  where
    synth = synthConfig cfg
    -- Features never include the column being predicted.
    features = exclude [target] df
    candidates =
        numericConditions synth features
            ++ generateConditionsOld synth features
fitDecisionTree _ expr _ =
    error $ "Cannot create tree for compound expression: " ++ show expr

-- | Recursively grow a decision tree for the column @target@.
--
-- A node becomes a leaf holding the majority value of @target@
-- ('majorityValue') when any of these holds:
--
-- * the depth budget is exhausted (@depth <= 0@);
-- * the node has fewer than 'minSamplesSplit' rows (always when it has
--   fewer than 2, since such a node cannot be split into two non-empty
--   halves);
-- * 'findBestSplit' returns no condition;
-- * the chosen condition leaves one side of the partition empty.
--
-- Otherwise the node branches on the chosen condition ('F.ifThenElse'). Both
-- halves are grown with @depth - 1@ from the same candidate list, and the
-- result is simplified with 'pruneTree'.
buildTree ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    Int ->
    T.Text ->
    [Expr Bool] ->
    DataFrame ->
    Expr a
buildTree cfg depth target conds df
    -- Strict '<': a node with exactly 'minSamplesSplit' rows may still be
    -- split, matching the usual "minimum number of samples required to
    -- split" meaning. The floor of 2 keeps settings <= 1 behaving as before.
    | depth <= 0 || nRows df < max 2 (minSamplesSplit cfg) = leaf
    | otherwise = case findBestSplit @a cfg target conds df of
        Nothing -> leaf
        Just bestCond ->
            let (dfTrue, dfFalse) = partitionDataFrame bestCond df
             in if nRows dfTrue == 0 || nRows dfFalse == 0
                    then leaf
                    else
                        pruneTree
                            (F.ifThenElse bestCond (grow dfTrue) (grow dfFalse))
  where
    -- Prediction used whenever this node stops splitting.
    leaf :: Expr a
    leaf = Lit (majorityValue @a target df)

    -- Grow a child subtree with one less level of depth budget.
    grow :: DataFrame -> Expr a
    grow = buildTree @a cfg (depth - 1) target conds

-- | Simplify a fitted expression tree bottom-up.
--
-- The branches of every 'If' are simplified first. Then:
--
-- * if the two simplified branches are equal, the 'If' is replaced by that
--   branch;
-- * if a branch is itself an 'If' on the very same condition, the inner test
--   is resolved statically:
--   @if C1 then (if C1 then X else Y) else Z@ becomes @if C1 then X else Z@,
--   and @if C1 then X else (if C1 then Y else Z)@ becomes
--   @if C1 then X else Z@. Generalize this with hegg later.
--
-- 'UnaryOp' and 'BinaryOp' nodes are rebuilt around simplified operands; any
-- other node is returned unchanged.
pruneTree :: forall a. (Columnable a, Eq a) => Expr a -> Expr a
pruneTree node = case node of
    If cond trueBranch falseBranch ->
        collapse cond (pruneTree trueBranch) (pruneTree falseBranch)
    UnaryOp name op operand -> UnaryOp name op (pruneTree operand)
    BinaryOp name op lhs rhs -> BinaryOp name op (pruneTree lhs) (pruneTree rhs)
    _ -> node
  where
    -- Rebuild an 'If' from already-simplified branches, applying the rules
    -- above in order.
    collapse :: Expr Bool -> Expr a -> Expr a -> Expr a
    collapse cond t f
        | t == f = t
        | If inner tInner _ <- t, cond == inner = If cond tInner f
        | If inner _ fInner <- f, cond == inner = If cond t fInner
        | otherwise = If cond t f

-- | A candidate-condition generator: given the synthesis settings and a
-- frame, produce boolean split conditions to consider. In 'fitDecisionTree'
-- the frame passed in already has the target column excluded.
type CondGen = SynthConfig -> DataFrame -> [Expr Bool]

-- | Numeric threshold condition generator; delegates to
-- 'generateNumericConds'.
numericConditions :: CondGen
numericConditions synth frame = generateNumericConds synth frame

-- | Threshold conditions over synthesized numeric features.
--
-- For every feature produced by 'numericExprsWithTerms', the feature's
-- values at the ranks in 'percentiles' are computed with 'percentile'. For
-- each distinct resulting threshold @t@, the conditions @e <= t@, @e >= t@,
-- @e < t@ and @e > t@ are emitted, in that order.
--
-- Repeated thresholds are dropped with 'nubOrd', keeping the first
-- occurrence. These are common for low-cardinality or integer-valued
-- features, where several ranks land on the same value. Dropping them means
-- identical conditions are not re-evaluated by the split search.
generateNumericConds ::
    SynthConfig -> DataFrame -> [Expr Bool]
generateNumericConds cfg df = do
    expr <- numericExprsWithTerms cfg df
    threshold <- nubOrd [percentile p expr df | p <- percentiles cfg]
    let bound = F.lit threshold
    [expr .<= bound, expr .>= bound, expr .< bound, expr .> bound]

-- | All synthesized numeric features up to 'maxExprDepth'.
--
-- 'numericExprs' started at depth 0 with bound @m@ returns the base columns
-- followed by every arithmetic level below @m@. Its result for
-- @m = maxExprDepth@ therefore already contains every feature that any
-- smaller bound would produce.
--
-- Calling it once with that bound, instead of once per bound in
-- @[0 .. maxExprDepth]@, avoids emitting the base columns (and lower levels)
-- repeatedly. The repeats would otherwise each be turned into conditions and
-- scored again. The features keep the same relative order.
--
-- A negative 'maxExprDepth' yields no features.
numericExprsWithTerms ::
    SynthConfig -> DataFrame -> [Expr Double]
numericExprsWithTerms cfg df
    | maxExprDepth cfg < 0 = []
    | otherwise = numericExprs cfg df [] 0 (maxExprDepth cfg)

-- | The frame's numeric columns as 'Double' expressions, in 'columnNames'
-- order.
--
-- * Unboxed 'Double' columns are referenced directly.
-- * Unboxed integral columns are wrapped in 'F.toDouble'.
-- * Every other column (boxed, optional, other unboxed types) is skipped.
numericCols :: DataFrame -> [Expr Double]
numericCols df = [e | name <- columnNames df, e <- asDouble name]
  where
    asDouble :: T.Text -> [Expr Double]
    asDouble name = case unsafeGetColumn name df of
        UnboxedColumn (_ :: VU.Vector b) ->
            case testEquality (typeRep @b) (typeRep @Double) of
                Just Refl -> [Col name]
                Nothing -> case sIntegral @b of
                    STrue -> [F.toDouble (Col @b name)]
                    SFalse -> []
        _ -> []

-- | Enumerate numeric feature expressions level by level.
--
-- @numericExprs cfg df prev depth maxDepth@:
--
-- * at @depth == 0@: the numeric base columns ('numericCols') followed by the
--   deeper levels, seeded with those base columns. This branch is taken even
--   when @maxDepth <= 0@.
-- * once @depth >= maxDepth@: nothing.
-- * otherwise: every expression of the previous level combined with every
--   distinct base column using @+@, @-@, @*@ and a guarded division,
--   followed by the next level seeded with these combinations.
--
-- When 'enableArithOps' is off, no combinations are produced.
--
-- The division is @if e2 > 0 then e1 / e2 else 0@. 'Double' division by zero
-- does not fail; it yields ±Infinity or NaN, which would leak into the
-- percentile thresholds. A zero denominator is therefore mapped to 0, just
-- like a negative one.
numericExprs ::
    SynthConfig -> DataFrame -> [Expr Double] -> Int -> Int -> [Expr Double]
numericExprs cfg df prevExprs depth maxDepth
    | depth == 0 =
        baseExprs ++ numericExprs cfg df baseExprs (depth + 1) maxDepth
    | depth >= maxDepth = []
    | otherwise =
        combinedExprs ++ numericExprs cfg df combinedExprs (depth + 1) maxDepth
  where
    baseExprs :: [Expr Double]
    baseExprs = numericCols df

    combinedExprs :: [Expr Double]
    combinedExprs
        | not (enableArithOps cfg) = []
        | otherwise = do
            e1 <- prevExprs
            e2 <- baseExprs
            guard (e1 /= e2)
            [ e1 + e2
                , e1 - e2
                , e1 * e2
                , F.ifThenElse (e2 .> 0) (e1 / e2) 0
                ]

-- | Enumerate boolean combinations of conditions level by level.
--
-- @boolExprs df base prev depth maxDepth@:
--
-- * at @depth == 0@: @base@ followed by the deeper levels, seeded with @prev@.
-- * once @depth >= maxDepth@: nothing.
-- * otherwise: for every @e1@ in @prev@ and every @e2@ in @base@ with
--   @e1 /= e2@, the conjunction and then the disjunction of the pair,
--   followed by the next level seeded with these combinations.
--
-- The frame argument is only threaded through the recursion.
--
-- NOTE(review): unlike 'numericExprs', the depth-0 step seeds the next level
-- with the caller's @prev@ rather than with @base@. A caller passing
-- @prev = []@ therefore gets no combinations at all — confirm this is
-- intended.
boolExprs ::
    DataFrame -> [Expr Bool] -> [Expr Bool] -> Int -> Int -> [Expr Bool]
boolExprs df baseExprs prevExprs depth maxDepth
    | depth == 0 = baseExprs ++ nextLevel prevExprs
    | depth >= maxDepth = []
    | otherwise = combined ++ nextLevel combined
  where
    nextLevel seed = boolExprs df baseExprs seed (depth + 1) maxDepth
    combined =
        [ combine e1 e2
        | e1 <- prevExprs
        , e2 <- baseExprs
        , e1 /= e2
        , combine <- [F.and, F.or]
        ]

-- | Legacy per-column and column-pair condition generator.
--
-- Per column:
--
-- * Boxed columns: equality tests against the values at ranks 1, 25, 75 and
--   99 computed by 'percentileOrd''.
-- * Optional floating or integral columns: a null test (@col == Nothing@).
--   When the column has non-null values, this is followed by @col <= t@ and
--   @col >= t@ for every distinct threshold @t@. The thresholds are taken
--   from the non-null values at the ranks in 'percentiles'. Integral
--   thresholds are rounded back to the column type.
-- * Other optional columns: equality tests at ranks 1, 25, 75 and 99.
-- * Unboxed columns: no conditions here (numeric unboxed columns are handled
--   by 'numericCols').
--
-- When 'enableCrossCols' is set, every ordered pair of columns with the same
-- element type also contributes conditions; this includes a column paired
-- with itself.
--
-- * Pairs of boxed columns contribute an equality test.
-- * Pairs of optional columns contribute an equality test.
-- * Pairs of optional non-'T.Text' columns additionally contribute a @<=@
--   test.
generateConditionsOld :: SynthConfig -> DataFrame -> [Expr Bool]
generateConditionsOld cfg df =
    concatMap genConds names ++ columnConds
  where
    names :: [T.Text]
    names = columnNames df

    -- Fixed ranks for the equality probes built with 'percentileOrd''.
    ordRanks :: [Int]
    ordRanks = [1, 25, 75, 99]

    genConds :: T.Text -> [Expr Bool]
    genConds colName = case unsafeGetColumn colName df of
        BoxedColumn (col :: V.Vector a) ->
            map (\k -> Col @a colName .== Lit (percentileOrd' k col)) ordRanks
        OptionalColumn (col :: V.Vector (Maybe a)) ->
            let
                colExpr :: Expr (Maybe a)
                colExpr = Col @(Maybe a) colName

                isNull :: Expr Bool
                isNull = colExpr .== Lit Nothing

                -- Null test plus two-sided tests at every distinct threshold.
                -- The threshold list is built from each percentile rank, so
                -- no threshold is silently dropped. @present@ holds the
                -- non-null values as Double; @fromDouble@ maps a percentile
                -- back into the column's element type.
                rangeConds :: V.Vector Double -> (Double -> a) -> [Expr Bool]
                rangeConds present fromDouble
                    -- NOTE(review): 'percentile'' is not assumed to accept an
                    -- empty vector, so an all-null column only gets the null test.
                    | V.null present = [isNull]
                    | otherwise = isNull : concatMap bothSides thresholds
                  where
                    dense = VU.convert present
                    thresholds =
                        nubOrd
                            [Just (fromDouble (percentile' p dense)) | p <- percentiles cfg]
                    bothSides t = [colExpr .<= Lit t, colExpr .>= Lit t]
             in
                case sFloating @a of
                    STrue ->
                        rangeConds
                            (V.mapMaybe (fmap (realToFrac @a @Double)) col)
                            realToFrac
                    SFalse -> case sIntegral @a of
                        STrue ->
                            rangeConds
                                (V.mapMaybe (fmap (fromIntegral @a @Double)) col)
                                round
                        SFalse ->
                            map (\k -> colExpr .== Lit (percentileOrd' k col)) ordRanks
        UnboxedColumn _ -> []

    -- Column-vs-column comparisons, only when cross-column synthesis is on.
    columnConds :: [Expr Bool]
    columnConds
        | enableCrossCols cfg = concatMap colConds [(l, r) | l <- names, r <- names]
        | otherwise = []

    colConds :: (T.Text, T.Text) -> [Expr Bool]
    colConds (!l, !r) = case (unsafeGetColumn l df, unsafeGetColumn r df) of
        (BoxedColumn (_ :: V.Vector a), BoxedColumn (_ :: V.Vector b)) ->
            case testEquality (typeRep @a) (typeRep @b) of
                Nothing -> []
                Just Refl -> [Col @a l .== Col @a r]
        (OptionalColumn (_ :: V.Vector (Maybe a)), OptionalColumn (_ :: V.Vector (Maybe b))) ->
            case testEquality (typeRep @a) (typeRep @b) of
                Nothing -> []
                Just Refl -> case testEquality (typeRep @a) (typeRep @T.Text) of
                    Nothing -> [Col @(Maybe a) l .<= Col r, Col @(Maybe a) l .== Col r]
                    Just Refl -> [Col @(Maybe a) l .== Col r]
        _ -> []

-- | Split a frame by a condition.
--
-- The first component holds the rows kept by @filterWhere cond@. The second
-- holds the rows kept by @filterWhere (F.not cond)@.
partitionDataFrame :: Expr Bool -> DataFrame -> (DataFrame, DataFrame)
partitionDataFrame cond df = (keeping cond, keeping (F.not cond))
  where
    keeping predicate = filterWhere predicate df

-- | Pick the best split condition for the @target@ column (values read at type
-- @a@), or 'Nothing' if no admissible split exists.
--
-- A condition is admissible only when both sides of 'partitionDataFrame' keep
-- at least 'minLeafSize' rows. This applies to the input conditions and to
-- the expressions synthesized by 'boolExprs', since combining conditions can
-- shrink a side below the limit. Admissible conditions are scored by
--
-- > ((giniBefore - weightedGiniAfter) - complexityPenalty * eSize cond, negate (eSize cond))
--
-- so that, at equal gain, the smaller expression wins.
--
-- Procedure:
--
-- 1. De-duplicate @conds@, then keep and score the admissible ones.
-- 2. Take the ten highest-scoring conditions and expand them with 'boolExprs'
--    (expansion depth from 'boolExpansion').
-- 3. Return the highest-scoring admissible expression.
--
-- Returns 'Nothing' instead of crashing when nothing is admissible.
findBestSplit ::
    forall a.
    (Columnable a) =>
    TreeConfig -> T.Text -> [Expr Bool] -> DataFrame -> Maybe (Expr Bool)
findBestSplit cfg target conds df
    | null scoredBase = Nothing
    | otherwise =
        case mapMaybe scored candidates of
            [] -> Nothing
            admissible -> Just (snd (maximumBy (compare `on` fst) admissible))
  where
    initialImpurity :: Double
    initialImpurity = calculateGini @a target df

    totalRows :: Double
    totalRows = fromIntegral @Int @Double (nRows df)

    -- Partition once, reject splits that leave a side below minLeafSize, and
    -- otherwise return the (gain, -size) score. Each condition is scored once
    -- and the score is reused for sorting and for the final maximum, instead
    -- of re-partitioning the frame on every comparison.
    scoreSplit :: Expr Bool -> Maybe (Double, Int)
    scoreSplit cond = do
        let (t, f) = partitionDataFrame cond df
        guard (nRows t >= minLeafSize cfg && nRows f >= minLeafSize cfg)
        let weightedGini side =
                fromIntegral @Int @Double (nRows side) / totalRows
                    * calculateGini @a target side
            newImpurity = weightedGini t + weightedGini f
            size = eSize cond
        pure
            ( (initialImpurity - newImpurity)
                - complexityPenalty (synthConfig cfg) * fromIntegral size
            , negate size
            )

    -- Pair an admissible condition with its score (decorate once).
    scored :: Expr Bool -> Maybe ((Double, Int), Expr Bool)
    scored cond = fmap (\s -> (s, cond)) (scoreSplit cond)

    scoredBase :: [((Double, Int), Expr Bool)]
    scoredBase = mapMaybe scored (nubOrd conds)

    -- The ten best base conditions, in descending score order. sortBy is
    -- stable, so equal scores keep their de-duplicated input order.
    topConditions :: [Expr Bool]
    topConditions =
        map snd (take 10 (sortBy (flip compare `on` fst) scoredBase))

    -- Only forced when scoredBase is non-empty, so boolExprs is not called
    -- when no base condition is admissible.
    candidates :: [Expr Bool]
    candidates =
        boolExprs
            df
            topConditions
            topConditions
            0
            (boolExpansion (synthConfig cfg))

-- | Gini impurity of the @target@ column (values read at type @a@), using
-- add-one smoothing over the distinct values that occur in the frame:
--
-- > 1 - sum_k ((count_k + 1) / (n + K))^2
--
-- where @n@ is the row count and @K@ the number of distinct values found by
-- 'getCounts'. A frame with zero rows has impurity 0; in that case the counts
-- are never computed.
calculateGini ::
    forall a.
    (Columnable a) =>
    T.Text -> DataFrame -> Double
calculateGini target df
    | rowCount == 0 = 0
    | otherwise = 1 - sum [p ^ (2 :: Int) | p <- smoothedProbs]
  where
    rowCount :: Double
    rowCount = fromIntegral (nRows df)

    classCounts :: M.Map a Int
    classCounts = getCounts @a target df

    smoothedProbs :: [Double]
    smoothedProbs =
        let denom = rowCount + fromIntegral (M.size classCounts)
         in [(fromIntegral c + 1) / denom | c <- M.elems classCounts]

-- | Most frequent value of the @target@ column (read at type @a@).
--
-- Ties on the highest count are resolved by 'maximumBy' over the
-- ascending-key 'M.toList' order. Base's 'maximumBy' keeps the later of two
-- equal elements, so the greatest tied value is returned.
--
-- Calls 'error' when the column yields no values.
majorityValue ::
    forall a.
    (Columnable a) =>
    T.Text -> DataFrame -> a
majorityValue target df =
    case M.toList (getCounts @a target df) of
        [] -> error "Empty DataFrame in leaf"
        tallies -> fst (maximumBy (compare `on` snd) tallies)

-- | Count the occurrences of each distinct value in the @target@ column,
-- read at type @a@.
--
-- Any 'Left' returned by 'interpret' or 'toVector' is rethrown as an
-- exception with 'throw'.
getCounts ::
    forall a.
    (Columnable a) =>
    T.Text -> DataFrame -> M.Map a Int
getCounts target df =
    case interpret @a df (Col target) of
        Left err -> throw err
        Right (TColumn col) -> either throw tally (toVector @a col)
  where
    -- Strict left fold: each value adds 1 to its running count.
    tally :: V.Vector a -> M.Map a Int
    tally = V.foldl' (\acc x -> M.insertWith (+) x 1 acc) M.empty