{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module DataFrame.DecisionTree where

import qualified DataFrame.Functions as F
import DataFrame.Internal.Column
import DataFrame.Internal.DataFrame (DataFrame (..), unsafeGetColumn)
import DataFrame.Internal.Expression (Expr (..), eSize, getColumns)
import DataFrame.Internal.Interpreter (interpret)
import DataFrame.Internal.Statistics (percentile', percentileOrd')
import DataFrame.Internal.Types
import DataFrame.Operations.Core (columnNames, nRows)
import DataFrame.Operations.Subset (exclude, filterWhere)

import Control.Exception (throw)
import Control.Monad (guard)
import Data.Containers.ListUtils (nubOrd)
import Data.Function (on)
import Data.List (foldl', maximumBy, minimumBy, sort, sortBy)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Text as T
import Data.Type.Equality
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as VU
import Type.Reflection (typeRep)

import DataFrame.Operators

-- | Settings that control how a decision tree is grown and then refined.
data TreeConfig = TreeConfig
    { maxTreeDepth :: Int
    -- ^ Depth budget handed to the greedy builder that produces the initial tree.
    , minSamplesSplit :: Int
    -- ^ Not read in this part of the module; presumably the minimum number of
    -- rows a node needs before a split is attempted (TODO confirm).
    , minLeafSize :: Int
    -- ^ Minimum number of rows that each side of a split must receive for the
    -- split to be accepted.
    , percentiles :: [Int]
    -- ^ Not read in this part of the module; presumably the percentile cut
    -- points used when generating numeric thresholds (TODO confirm).
    , expressionPairs :: Int
    -- ^ How many of the best-scoring candidate conditions are kept before they
    -- are expanded into boolean combinations.
    , synthConfig :: SynthConfig
    -- ^ Settings for synthesising candidate expressions.
    , taoIterations :: Int
    -- ^ Upper bound on the number of TAO refinement passes.
    , taoConvergenceTol :: Double
    -- ^ Refinement stops once a pass lowers the loss by less than this amount.
    }
    deriving (Eq, Show)

-- | Settings that control which candidate expressions are synthesised.
data SynthConfig = SynthConfig
    { maxExprDepth :: Int
    -- ^ Not read in this part of the module; presumably the maximum nesting
    -- depth of generated expressions (TODO confirm).
    , boolExpansion :: Int
    -- ^ Bound passed to @boolExprs@ when combining candidate conditions;
    -- presumably the number of expansion rounds (TODO confirm).
    , disallowedCombinations :: [(T.Text, T.Text)]
    -- ^ Not read in this part of the module; presumably pairs of column names
    -- that must not be combined (TODO confirm).
    , complexityPenalty :: Double
    -- ^ Multiplied by a condition's expression size and rounded down to give
    -- the extra error count charged to that condition.
    , enableStringOps :: Bool
    -- ^ Not read in this part of the module (TODO confirm meaning).
    , enableCrossCols :: Bool
    -- ^ Not read in this part of the module (TODO confirm meaning).
    , enableArithOps :: Bool
    -- ^ Not read in this part of the module (TODO confirm meaning).
    }
    deriving (Eq, Show)

-- | Default synthesis settings: every operator family enabled, no disallowed
-- column pairs, and a complexity penalty of 0.05.
defaultSynthConfig :: SynthConfig
defaultSynthConfig =
    SynthConfig
        { maxExprDepth = 2
        , boolExpansion = 2
        , disallowedCombinations = []
        , complexityPenalty = 0.05
        , enableStringOps = True
        , enableCrossCols = True
        , enableArithOps = True
        }

-- | Default tree settings: depth 4, percentiles 0, 10, .. 100, at most ten
-- refinement passes, and 'defaultSynthConfig' for expression synthesis.
defaultTreeConfig :: TreeConfig
defaultTreeConfig =
    TreeConfig
        { maxTreeDepth = 4
        , minSamplesSplit = 5
        , minLeafSize = 1
        , percentiles = [0, 10 .. 100]
        , expressionPairs = 10
        , synthConfig = defaultSynthConfig
        , taoIterations = 10
        , taoConvergenceTol = 1e-6
        }

-- | A binary decision tree whose leaves carry predictions of type @a@.
data Tree a
    = -- | A terminal node holding the predicted value.
      Leaf !a
    | -- | A test on a boolean expression; rows for which it is true follow the
      -- first subtree, all other rows follow the second.
      Branch !(Expr Bool) !(Tree a) !(Tree a)
    deriving (Eq, Show)

-- | Number of branches on the longest path from the root to a leaf.
-- A lone 'Leaf' has depth 0.
treeDepth :: Tree a -> Int
treeDepth (Leaf _) = 0
treeDepth (Branch _ l r) = 1 + max (treeDepth l) (treeDepth r)

-- | Turn a tree into an expression: leaves become literals and every branch
-- becomes an if-then-else on its condition.
treeToExpr :: (Columnable a) => Tree a -> Expr a
treeToExpr (Leaf v) = Lit v
treeToExpr (Branch cond left right) =
    F.ifThenElse cond (treeToExpr left) (treeToExpr right)

{- | Fit a TAO decision tree for a target column and return it as an expression.

The target must be a plain column reference (@Col name@). Candidate split
conditions are generated from every column except the target, a greedy tree is
built up to 'maxTreeDepth', that tree is refined with 'taoOptimize', and the
result is converted with 'treeToExpr' and passed through @pruneExpr@.

Calls 'error' when the target is any other kind of expression.
-}
fitDecisionTree ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    Expr a ->
    DataFrame ->
    Expr a
fitDecisionTree cfg (Col target) df =
    pruneExpr (treeToExpr optimizedTree)
  where
    -- Conditions are generated from a frame without the target column, so no
    -- candidate can test the value being predicted. Built once and shared by
    -- both generators.
    features :: DataFrame
    features = exclude [target] df

    conds :: [Expr Bool]
    conds =
        nubOrd $
            numericConditions cfg features
                ++ generateConditionsOld cfg features

    initialTree :: Tree a
    initialTree = buildGreedyTree @a cfg (maxTreeDepth cfg) target conds df

    -- Every row of the frame reaches the root.
    rootIndices :: V.Vector Int
    rootIndices = V.enumFromN 0 (nRows df)

    optimizedTree :: Tree a
    optimizedTree = taoOptimize @a cfg target conds df rootIndices initialTree
fitDecisionTree _ expr _ =
    error $ "Cannot create tree for compound expression: " ++ show expr

{- | Repeatedly refine a tree with 'taoIteration' until it stops improving.

The loop ends when 'taoIterations' passes have run, or when a pass lowers the
loss (as reported by @computeTreeLoss@) by less than 'taoConvergenceTol'. The
returned tree is always passed through @pruneDead@.
-}
taoOptimize ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    T.Text -> -- Target column name
    [Expr Bool] -> -- Candidate conditions
    DataFrame -> -- Full dataset
    V.Vector Int -> -- Indices of points reaching the root
    Tree a -> -- Current tree
    Tree a
taoOptimize cfg target conds df rootIndices initialTree =
    go 0 initialTree (computeTreeLoss @a target df rootIndices initialTree)
  where
    go :: Int -> Tree a -> Double -> Tree a
    go iter tree prevLoss
        | iter >= taoIterations cfg = pruneDead tree
        -- A pass that raised the loss yields a negative "improvement", which
        -- also falls below the tolerance. Keep the tree from before the pass
        -- rather than returning the worse one.
        | newLoss > prevLoss = pruneDead tree
        | prevLoss - newLoss < taoConvergenceTol cfg = pruneDead tree'
        | otherwise = go (iter + 1) tree' newLoss
      where
        tree' :: Tree a
        tree' = taoIteration @a cfg target conds df rootIndices tree

        newLoss :: Double
        newLoss = computeTreeLoss @a target df rootIndices tree'

{- | Run one refinement pass over the whole tree.

Levels are visited from the deepest one up to the root (depth 0); the tree
produced by one level is the input to the next. The depth is measured once, on
the incoming tree.
-}
taoIteration ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    T.Text ->
    [Expr Bool] ->
    DataFrame ->
    V.Vector Int ->
    Tree a ->
    Tree a
taoIteration cfg target conds df rootIndices tree =
    foldl' optimizeLevel tree bottomToTop
  where
    optimizeLevel :: Tree a -> Int -> Tree a
    optimizeLevel = optimizeDepthLevel @a cfg target conds df rootIndices

    bottomToTop :: [Int]
    bottomToTop = [deepest, deepest - 1 .. 0]

    deepest :: Int
    deepest = treeDepth tree

-- | Re-optimise every node that sits at the given depth, counting the root as
-- depth 0. Nodes at other depths are left as they are.
optimizeDepthLevel ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    T.Text ->
    [Expr Bool] ->
    DataFrame ->
    V.Vector Int ->
    Tree a ->
    Int -> -- Target depth
    Tree a
optimizeDepthLevel cfg target conds df rootIndices tree targetDepth =
    -- The walk starts at the root, which is depth 0.
    optimizeAtDepth @a cfg target conds df rootIndices tree 0 targetDepth

{- | Walk down the tree and re-optimise the nodes found at @targetDepth@.

The two 'Int' arguments are the depth of the node currently being visited and
the depth to optimise, in that order. On the way down, the row indices are
split by each branch condition so that every node is optimised only against
the rows that reach it. A leaf above the target depth is returned unchanged.
-}
optimizeAtDepth ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    T.Text ->
    [Expr Bool] ->
    DataFrame ->
    V.Vector Int ->
    Tree a ->
    Int ->
    Int ->
    Tree a
optimizeAtDepth cfg target conds df indices tree currentDepth targetDepth
    | currentDepth == targetDepth =
        optimizeNode @a cfg target conds df indices tree
    | otherwise = case tree of
        Leaf _ -> tree
        Branch cond left right ->
            let (indicesL, indicesR) = partitionIndices cond df indices
             in Branch cond (descend indicesL left) (descend indicesR right)
  where
    -- Recurse one level deeper with the rows routed to that child.
    descend :: V.Vector Int -> Tree a -> Tree a
    descend rows subtree =
        optimizeAtDepth @a
            cfg
            target
            conds
            df
            rows
            subtree
            (currentDepth + 1)
            targetDepth

{- | Re-optimise a single node against the rows that reach it.

* If no rows reach the node, it is returned unchanged.
* A leaf is relabelled with the majority target value of its rows.
* A branch gets the condition chosen by 'findBestSplitTAO', with both subtrees
  kept as they are. If that condition would send fewer than 'minLeafSize' rows
  to either side, the branch is replaced by a majority-value leaf.
-}
optimizeNode ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    T.Text ->
    [Expr Bool] ->
    DataFrame ->
    V.Vector Int ->
    Tree a ->
    Tree a
optimizeNode cfg target conds df indices tree
    | V.null indices = tree
    | otherwise = case tree of
        Leaf _ -> majorityLeaf
        Branch oldCond left right ->
            let
                newCond =
                    findBestSplitTAO @a cfg target conds df indices left right oldCond
                (newIndicesL, newIndicesR) = partitionIndices newCond df indices
             in
                if tooSmall newIndicesL || tooSmall newIndicesR
                    then majorityLeaf
                    else Branch newCond left right
  where
    majorityLeaf :: Tree a
    majorityLeaf = Leaf (majorityValueFromIndices @a target df indices)

    tooSmall :: V.Vector Int -> Bool
    tooSmall side = V.length side < minLeafSize cfg

{- | Choose the condition for a branch while both of its subtrees stay fixed.

A candidate's score is the number of care points it routes to the wrong child
plus a complexity penalty (@floor (complexityPenalty * eSize condition)@);
lower is better. The search proceeds as follows:

1. Keep only candidates that send at least 'minLeafSize' rows to each side.
2. Take the 'expressionPairs' best-scoring of those.
3. Expand them with @boolExprs@ and return the best-scoring expansion.

The current condition is returned unchanged when no rows reach the node, when
no candidate is valid, when there are no care points, or when the expansion is
empty.
-}
findBestSplitTAO ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    T.Text ->
    [Expr Bool] ->
    DataFrame ->
    V.Vector Int ->
    Tree a -> -- Left subtree (FIXED)
    Tree a -> -- Right subtree (FIXED)
    Expr Bool -> -- Current condition (fallback)
    Expr Bool
findBestSplitTAO cfg target conds df indices leftTree rightTree currentCond
    | V.null indices = currentCond
    | null validConds = currentCond
    | null carePoints = currentCond
    | null expandedConds = currentCond
    | otherwise = snd (minimumBy (compare `on` fst) (scored expandedConds))
  where
    validConds :: [Expr Bool]
    validConds = filter isValidSplit conds

    isValidSplit :: Expr Bool -> Bool
    isValidSplit c =
        let (t, f) = partitionIndices c df indices
         in V.length t >= minLeafSize cfg && V.length f >= minLeafSize cfg

    carePoints :: [CarePoint]
    carePoints = identifyCarePoints @a target df indices leftTree rightTree

    evalWithPenalty :: Expr Bool -> Int
    evalWithPenalty c = errors + penalty
      where
        errors = countCarePointErrors c df carePoints
        penalty =
            floor
                ( complexityPenalty (synthConfig cfg)
                    * fromIntegral (eSize c)
                )

    -- Pair each condition with its score so the score is computed at most once
    -- per condition. Scoring interprets the condition over the frame, so
    -- recomputing it inside every comparison is wasteful.
    scored :: [Expr Bool] -> [(Int, Expr Bool)]
    scored = map (\c -> (evalWithPenalty c, c))

    -- 'sortBy' is stable, so equally scored conditions keep their input order.
    sortedConds :: [Expr Bool]
    sortedConds =
        take (expressionPairs cfg) $
            map snd $
                sortBy (compare `on` fst) (scored validConds)

    expandedConds :: [Expr Bool]
    expandedConds =
        boolExprs
            df
            sortedConds
            sortedConds
            0
            (boolExpansion (synthConfig cfg))

-- | A care point: a row for which exactly one child subtree predicts the
-- correct label, together with the direction of that child.
data CarePoint = CarePoint
    { cpIndex :: !Int
    -- ^ Row index of the point.
    , cpCorrectDir :: !Direction
    -- ^ Which child classifies this point correctly.
    }
    deriving (Eq, Show)

-- | Which child of a branch a row is routed to.
data Direction = GoLeft | GoRight
    deriving (Eq, Show)

{- | Identify care points: rows where exactly one subtree classifies correctly.

For each row index reaching the node:

1. Compute the label the left subtree would predict.
2. Compute the label the right subtree would predict.
3. If exactly one of them equals the true label, the row is a care point.
4. Record which direction leads to the correct classification.

Rows that both subtrees get right, or both get wrong, are skipped. An empty
list is returned when the target column cannot be read as a vector of @a@.
-}
identifyCarePoints ::
    forall a.
    (Columnable a) =>
    T.Text ->
    DataFrame ->
    V.Vector Int ->
    Tree a -> -- Left subtree
    Tree a -> -- Right subtree
    [CarePoint]
identifyCarePoints target df indices leftTree rightTree =
    case interpret @a df (Col target) of
        Left _ -> []
        Right (TColumn col) ->
            case toVector @a col of
                Left _ -> []
                Right targetVals ->
                    V.toList $ V.mapMaybe (checkPoint targetVals) indices
  where
    -- Built once per subtree and shared by every row, so each branch condition
    -- is interpreted once rather than once per row.
    predictLeft, predictRight :: Int -> a
    predictLeft = rowPredictor leftTree
    predictRight = rowPredictor rightTree

    -- Turn a fixed tree into a row-index -> prediction function. Routing
    -- matches 'predictWithTree': true goes to the first subtree, and a
    -- condition that cannot be evaluated defaults to the first subtree.
    rowPredictor :: Tree a -> Int -> a
    rowPredictor (Leaf v) = const v
    rowPredictor (Branch cond left right) =
        case interpret @Bool df cond of
            Left _ -> viaLeft
            Right (TColumn col) ->
                case toVector @Bool col of
                    Left _ -> viaLeft
                    Right boolVals -> \idx ->
                        if boolVals V.! idx then viaLeft idx else viaRight idx
      where
        viaLeft = rowPredictor left
        viaRight = rowPredictor right

    checkPoint :: V.Vector a -> Int -> Maybe CarePoint
    checkPoint targetVals idx =
        case (predictLeft idx == trueLabel, predictRight idx == trueLabel) of
            (True, False) -> Just $ CarePoint idx GoLeft
            (False, True) -> Just $ CarePoint idx GoRight
            _ -> Nothing -- Don't-care point (both correct or both wrong)
      where
        trueLabel = targetVals V.! idx

{- | Predict the label for a single row using a fixed tree.

At each branch the condition is interpreted against the data frame and the
value at the row index decides the route: true follows the left subtree, false
the right. If the condition cannot be evaluated to a boolean vector, the left
subtree is taken. The target column name is only threaded through the
recursion.
-}
predictWithTree ::
    forall a.
    (Columnable a) =>
    T.Text ->
    DataFrame ->
    Int -> -- Row index
    Tree a ->
    a
predictWithTree _ _ _ (Leaf v) = v
predictWithTree target df idx (Branch cond left right)
    | goesLeft = predictWithTree @a target df idx left
    | otherwise = predictWithTree @a target df idx right
  where
    goesLeft :: Bool
    goesLeft =
        case interpret @Bool df cond of
            Left _ -> True -- Default to left on error
            Right (TColumn col) ->
                case toVector @Bool col of
                    Left _ -> True -- Default to left on error
                    Right boolVals -> boolVals V.! idx

{- | Count the care points that a condition routes to the wrong child.

A care point is misrouted when the condition's value at its row (true meaning
"go left") disagrees with its recorded correct direction. If the condition
cannot be evaluated to a boolean vector, every care point counts as an error.
-}
countCarePointErrors :: Expr Bool -> DataFrame -> [CarePoint] -> Int
countCarePointErrors cond df carePoints =
    case interpret @Bool df cond of
        Left _ -> allWrong
        Right (TColumn col) ->
            case toVector @Bool col of
                Left _ -> allWrong
                Right boolVals ->
                    length (filter (isMisclassified boolVals) carePoints)
  where
    allWrong :: Int
    allWrong = length carePoints

    isMisclassified :: V.Vector Bool -> CarePoint -> Bool
    isMisclassified boolVals cp = goesLeft /= shouldGoLeft
      where
        goesLeft = boolVals V.! cpIndex cp
        shouldGoLeft = cpCorrectDir cp == GoLeft

{- | Split row indices by a condition.

The first component holds the indices at which the condition is true, the
second those at which it is false. If the condition cannot be evaluated to a
boolean vector, all indices are placed in the first component and the second
is empty.
-}
partitionIndices ::
    Expr Bool -> DataFrame -> V.Vector Int -> (V.Vector Int, V.Vector Int)
partitionIndices cond df indices =
    case interpret @Bool df cond of
        Left _ -> unsplit
        Right (TColumn col) ->
            case toVector @Bool col of
                Left _ -> unsplit
                Right boolVals -> V.partition (boolVals V.!) indices
  where
    unsplit :: (V.Vector Int, V.Vector Int)
    unsplit = (indices, V.empty)

{- | Most frequent value of the @target@ column among the given row indices.

Occurrences are tallied in a map keyed by value. When several values share
the highest count, the one that comes last in the map's ascending key
order is returned (this follows from 'maximumBy' over 'M.toList').

Throws the underlying exception if the column cannot be interpreted or
read at type @a@, and calls 'error' when @indices@ is empty.
-}
majorityValueFromIndices ::
    forall a.
    (Columnable a) =>
    T.Text ->
    DataFrame ->
    V.Vector Int ->
    a
majorityValueFromIndices target df indices =
    case interpret @a df (Col target) of
        Left err -> throw err
        Right (TColumn col) -> case toVector @a col of
            Left err -> throw err
            Right vals ->
                let bump acc ix = M.insertWith (+) (vals V.! ix) (1 :: Integer) acc
                    tally = V.foldl' bump M.empty indices
                 in case M.toList tally of
                        [] -> error "Empty indices in majorityValueFromIndices"
                        pairs -> fst (maximumBy (compare `on` snd) pairs)

{- | Misclassification rate of @tree@ over the given row indices.

For every index the value stored in the @target@ column is compared with
the value returned by 'predictWithTree'; the result is the fraction of
indices where the two differ.

Returns @0@ for an empty index set, and @1.0@ when the target column
cannot be interpreted or read at type @a@.
-}
computeTreeLoss ::
    forall a.
    (Columnable a) =>
    T.Text ->
    DataFrame ->
    V.Vector Int ->
    Tree a ->
    Double
computeTreeLoss target df indices tree
    | V.null indices = 0
    | otherwise =
        case interpret @a df (Col target) of
            Right (TColumn col)
                | Right actual <- toVector @a col ->
                    let mispredicted ix =
                            actual V.! ix /= predictWithTree @a target df ix tree
                        wrong = V.length (V.filter mispredicted indices)
                     in fromIntegral wrong / fromIntegral (V.length indices)
            -- Unreadable target column: report the worst possible loss.
            _ -> 1.0

{- | Walk the tree and rebuild it node by node.

As written, every 'Leaf' and 'Branch' is reconstructed with the same
condition and (recursively rebuilt) children, so the result has the same
shape as the input; no branch is actually removed.
-}
pruneDead :: Tree a -> Tree a
pruneDead tree = case tree of
    Leaf v -> Leaf v
    Branch cond left right ->
        Branch cond (pruneDead left) (pruneDead right)

{- | Simplify redundant conditionals inside an expression.

Sub-expressions are pruned bottom-up. For an @If cond t f@ node:

* a pruned true branch of the form @If cond x _@ is replaced by @x@,
  because inside the true branch @cond@ is already known to hold;
* a pruned false branch of the form @If cond _ y@ is replaced by @y@,
  for the symmetric reason;
* if the two resulting branches are equal, the conditional is dropped
  and that branch is returned directly.

Both branches are collapsed independently and the equality check runs
after collapsing, so e.g. @If c (If c a b) a@ simplifies to @a@ and
@If c (If c a _) (If c _ b)@ simplifies to @If c a b@.

'Unary' and 'Binary' nodes are rebuilt around their pruned operands; any
other expression is returned unchanged.
-}
pruneExpr :: forall a. (Columnable a, Eq a) => Expr a -> Expr a
pruneExpr (If cond trueBranch falseBranch)
    | t == f = t
    | otherwise = If cond t f
  where
    t = underTrue (pruneExpr trueBranch)
    f = underFalse (pruneExpr falseBranch)

    -- Inside the true branch a repeated test of @cond@ always succeeds.
    underTrue :: Expr a -> Expr a
    underTrue (If inner whenTrue _) | inner == cond = whenTrue
    underTrue e = e

    -- Inside the false branch a repeated test of @cond@ always fails.
    underFalse :: Expr a -> Expr a
    underFalse (If inner _ whenFalse) | inner == cond = whenFalse
    underFalse e = e
pruneExpr (Unary op e) = Unary op (pruneExpr e)
pruneExpr (Binary op l r) = Binary op (pruneExpr l) (pruneExpr r)
pruneExpr e = e

{- | Recursively grow a tree by greedy splitting.

A leaf holding 'majorityValue' of the current frame is produced when

* the remaining @depth@ is zero or negative,
* the frame has at most 'minSamplesSplit' rows,
* 'findBestGreedySplit' finds no condition, or
* either side of the chosen split has fewer than 'minLeafSize' rows.

Otherwise a 'Branch' on the chosen condition is built, with the first and
second frames returned by 'partitionDataFrame' as its two subtrees, each
grown with @depth - 1@.
-}
buildGreedyTree ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    Int ->
    T.Text ->
    [Expr Bool] ->
    DataFrame ->
    Tree a
buildGreedyTree cfg depth target conds df
    | depth <= 0 = leaf
    | nRows df <= minSamplesSplit cfg = leaf
    | otherwise = maybe leaf grow (findBestGreedySplit @a cfg target conds df)
  where
    leaf :: Tree a
    leaf = Leaf (majorityValue @a target df)

    descend :: DataFrame -> Tree a
    descend = buildGreedyTree @a cfg (depth - 1) target conds

    tooSmall :: DataFrame -> Bool
    tooSmall part = nRows part < minLeafSize cfg

    grow :: Expr Bool -> Tree a
    grow bestCond
        | tooSmall onTrue || tooSmall onFalse = leaf
        | otherwise = Branch bestCond (descend onTrue) (descend onFalse)
      where
        (onTrue, onFalse) = partitionDataFrame bestCond df

{- | Pick the best split condition for the current frame, if any.

Steps:

1. Keep only base conditions whose two partitions both have at least
   'minLeafSize' rows.
2. Score each one as @(impurity reduction - complexity penalty, -size)@,
   where impurity comes from 'calculateGini' weighted by partition size
   and the penalty is 'complexityPenalty' times 'eSize'.
3. Sort by descending score, drop conditions whose raw impurity reduction
   is not positive, and keep the first 'expressionPairs' of them.
4. Expand that shortlist with 'boolExprs' (up to 'boolExpansion') and
   return the candidate with the highest score.

Returns 'Nothing' when the shortlist is empty.

Every candidate is partitioned and scored exactly once; the scores are
then compared, rather than being recomputed on each comparison.
-}
findBestGreedySplit ::
    forall a.
    (Columnable a) =>
    TreeConfig -> T.Text -> [Expr Bool] -> DataFrame -> Maybe (Expr Bool)
findBestGreedySplit cfg target conds df
    | null sortedConditions = Nothing
    | otherwise = Just (fst (maximumBy (compare `on` snd) scoredCandidates))
  where
    initialImpurity :: Double
    initialImpurity = calculateGini @a target df

    rowCount :: Double
    rowCount = fromIntegral @Int @Double (nRows df)

    calculateComplexity :: Expr Bool -> Double
    calculateComplexity c =
        complexityPenalty (synthConfig cfg) * fromIntegral (eSize c)

    meetsLeafSize :: (DataFrame, DataFrame) -> Bool
    meetsLeafSize (t, f) =
        nRows t >= minLeafSize cfg && nRows f >= minLeafSize cfg

    -- Score of a condition given its already computed partitions. The
    -- second component prefers smaller expressions on equal gain.
    scoreSplit :: Expr Bool -> (DataFrame, DataFrame) -> (Double, Int)
    scoreSplit cond (t, f) =
        let weightT = fromIntegral @Int @Double (nRows t) / rowCount
            weightF = fromIntegral @Int @Double (nRows f) / rowCount
            newImpurity =
                weightT * calculateGini @a target t
                    + weightF * calculateGini @a target f
         in ( (initialImpurity - newImpurity) - calculateComplexity cond
            , negate (eSize cond)
            )

    evalGain :: Expr Bool -> (Double, Int)
    evalGain cond = scoreSplit cond (partitionDataFrame cond df)

    -- Base conditions that respect the leaf-size limit, each partitioned
    -- once and paired with its score.
    scoredValid :: [(Expr Bool, (Double, Int))]
    scoredValid =
        [ (c, scoreSplit c parts)
        | c <- conds
        , let parts = partitionDataFrame c df
        , meetsLeafSize parts
        ]

    sortedConditions :: [Expr Bool]
    sortedConditions =
        map fst
            . take (expressionPairs cfg)
            -- score > -penalty  <=>  raw impurity reduction > 0
            . filter (\(c, v) -> fst v > negate (calculateComplexity c))
            . sortBy (flip compare `on` snd)
            $ scoredValid

    -- Shortlist plus its boolean combinations, each scored once.
    scoredCandidates :: [(Expr Bool, (Double, Int))]
    scoredCandidates =
        [ (c, evalGain c)
        | c <-
            boolExprs
                df
                sortedConditions
                sortedConditions
                0
                (boolExpansion (synthConfig cfg))
        ]

-- | Candidate numeric split conditions; delegates to 'generateNumericConds'.
numericConditions :: TreeConfig -> DataFrame -> [Expr Bool]
numericConditions cfg df = generateNumericConds cfg df

{- | Threshold comparisons over every synthesised numeric expression.

For each expression from 'numericExprsWithTerms' and each configured
percentile, the threshold is obtained with 'percentile' and four
conditions are emitted, in this order: @<=@, @>=@, @<@ and @>@ against
that threshold.
-}
generateNumericConds :: TreeConfig -> DataFrame -> [Expr Bool]
generateNumericConds cfg df =
    [ expr `cmp` F.lit threshold
    | expr <- numericExprsWithTerms (synthConfig cfg) df
    , let thresholds = [percentile p expr df | p <- percentiles cfg]
    , threshold <- thresholds
    , cmp <- [(.<=), (.>=), (.<), (.>)]
    ]

{- | All numeric expressions up to the configured 'maxExprDepth'.

'numericExprs' started at depth 0 with bound @d@ already returns every
level below @d@, so a single call with the deepest bound yields the full
set. Concatenating the results for every bound in @[0 .. maxExprDepth]@
would only repeat the shallower levels (the base columns alone would
appear @maxExprDepth + 1@ times), and every duplicate costs downstream
percentile and split evaluations.

The expressions come out in the same first-occurrence order as that
concatenation. A negative 'maxExprDepth' yields no expressions.
-}
numericExprsWithTerms :: SynthConfig -> DataFrame -> [Expr Double]
numericExprsWithTerms cfg df
    | deepest < 0 = []
    | otherwise = numericExprs cfg df [] 0 deepest
  where
    deepest = maxExprDepth cfg

{- | Column references usable as 'Double' expressions.

Only unboxed columns contribute: a column whose element type is 'Double'
is referenced directly, one whose element type is integral (according to
'sIntegral') is wrapped in 'F.toDouble', and anything else is skipped.
Columns appear in the order given by 'columnNames'.
-}
numericCols :: DataFrame -> [Expr Double]
numericCols df = [expr | name <- columnNames df, expr <- asDouble name]
  where
    asDouble :: T.Text -> [Expr Double]
    asDouble name = case unsafeGetColumn name df of
        UnboxedColumn (_ :: VU.Vector b)
            | Just Refl <- testEquality (typeRep @b) (typeRep @Double) ->
                [Col name]
            | otherwise -> case sIntegral @b of
                STrue -> [F.toDouble (Col @b name)]
                SFalse -> []
        _ -> []

{- | Enumerate numeric expressions level by level.

* At @depth == 0@ the base columns ('numericCols') are emitted and the
  recursion continues with them as the previous level.
* Once @depth >= maxDepth@ nothing more is produced.
* Otherwise each expression of the previous level is combined with each
  base column using @+@, @-@, @*@ and a division guarded against a zero
  divisor (yielding @0@ in that case); the new level is emitted and used
  as the previous level for the next step.

A pair is skipped when both operands are equal, or when the columns they
mention contain both members of any entry of 'disallowedCombinations'.
No combinations are produced unless 'enableArithOps' is set.
-}
numericExprs ::
    SynthConfig -> DataFrame -> [Expr Double] -> Int -> Int -> [Expr Double]
numericExprs cfg df prevExprs depth maxDepth
    | depth == 0 = columns ++ recurse columns
    | depth >= maxDepth = []
    | otherwise = derived ++ recurse derived
  where
    columns = numericCols df

    recurse seeds = numericExprs cfg df seeds (depth + 1) maxDepth

    -- True when both columns of a disallowed pair are mentioned.
    clashes names (l, r) = l `elem` names && r `elem` names

    permitted lhs rhs =
        let names = getColumns lhs <> getColumns rhs
         in lhs /= rhs
                && not (any (clashes names) (disallowedCombinations cfg))

    derived
        | enableArithOps cfg =
            [ combo
            | lhs <- prevExprs
            , rhs <- columns
            , permitted lhs rhs
            , combo <-
                [ lhs + rhs
                , lhs - rhs
                , lhs * rhs
                , F.ifThenElse (rhs ./= 0) (lhs / rhs) 0
                ]
            ]
        | otherwise = []

{- | Enumerate boolean combinations of the base conditions.

* At @depth == 0@ the base expressions are emitted and the recursion
  continues with the supplied previous level unchanged.
* Once @depth >= maxDepth@ nothing more is produced.
* Otherwise every previous-level expression is paired with every distinct
  base expression, emitting their 'F.and' followed by their 'F.or'; the
  new level becomes the previous level for the next step.

The 'DataFrame' argument is only threaded through the recursion.
-}
boolExprs ::
    DataFrame -> [Expr Bool] -> [Expr Bool] -> Int -> Int -> [Expr Bool]
boolExprs df baseExprs prevExprs depth maxDepth
    | depth == 0 = baseExprs ++ deeper prevExprs
    | depth >= maxDepth = []
    | otherwise = paired ++ deeper paired
  where
    deeper seeds = boolExprs df baseExprs seeds (depth + 1) maxDepth

    paired =
        concat
            [ [F.and lhs rhs, F.or lhs rhs]
            | lhs <- prevExprs
            , rhs <- baseExprs
            , lhs /= rhs
            ]

generateConditionsOld :: TreeConfig -> DataFrame -> [Expr Bool]
generateConditionsOld :: TreeConfig -> DataFrame -> [Expr Bool]
generateConditionsOld TreeConfig
cfg DataFrame
df =
    let
        genConds :: T.Text -> [Expr Bool]
        genConds :: Text -> [Expr Bool]
genConds Text
colName = case Text -> DataFrame -> Column
unsafeGetColumn Text
colName DataFrame
df of
            (BoxedColumn (Vector a
col :: V.Vector a)) ->
                let ps :: [Expr a]
ps = (Int -> Expr a) -> [Int] -> [Expr a]
forall a b. (a -> b) -> [a] -> [b]
map (a -> Expr a
forall a. Columnable a => a -> Expr a
Lit (a -> Expr a) -> (Int -> a) -> Int -> Expr a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Int -> Vector a -> a
forall a. (Ord a, Eq a) => Int -> Vector a -> a
`percentileOrd'` Vector a
col)) [Int
1, Int
25, Int
75, Int
99]
                 in (Expr a -> Expr Bool) -> [Expr a] -> [Expr Bool]
forall a b. (a -> b) -> [a] -> [b]
map (forall a. Columnable a => Text -> Expr a
Col @a Text
colName Expr a -> Expr a -> Expr Bool
forall a. (Columnable a, Eq a) => Expr a -> Expr a -> Expr Bool
.==) [Expr a]
ps
            (OptionalColumn (Vector (Maybe a)
col :: V.Vector (Maybe a))) -> case forall a. SBoolI (FloatingTypes a) => SBool (FloatingTypes a)
sFloating @a of
                SBool (FloatingTypes a)
STrue ->
                    let doubleCol :: Vector Double
doubleCol =
                            Vector Double -> Vector Double
forall (v :: * -> *) a (w :: * -> *).
(Vector v a, Vector w a) =>
v a -> w a
VU.convert
                                ((Maybe Double -> Double) -> Vector (Maybe Double) -> Vector Double
forall a b. (a -> b) -> Vector a -> Vector b
V.map Maybe Double -> Double
forall a. HasCallStack => Maybe a -> a
fromJust ((Maybe Double -> Bool)
-> Vector (Maybe Double) -> Vector (Maybe Double)
forall a. (a -> Bool) -> Vector a -> Vector a
V.filter Maybe Double -> Bool
forall a. Maybe a -> Bool
isJust ((Maybe a -> Maybe Double)
-> Vector (Maybe a) -> Vector (Maybe Double)
forall a b. (a -> b) -> Vector a -> Vector b
V.map ((a -> Double) -> Maybe a -> Maybe Double
forall a b. (a -> b) -> Maybe a -> Maybe b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall a b. (Real a, Fractional b) => a -> b
realToFrac @a @Double)) Vector (Maybe a)
col)))
                     in ((Expr (Maybe a) -> Expr Bool) -> Expr (Maybe a) -> Expr Bool)
-> [Expr (Maybe a) -> Expr Bool] -> [Expr (Maybe a)] -> [Expr Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
                            (Expr (Maybe a) -> Expr Bool) -> Expr (Maybe a) -> Expr Bool
forall a b. (a -> b) -> a -> b
($)
                            [ (forall a. Columnable a => Text -> Expr a
Col @(Maybe a) Text
colName Expr (Maybe a) -> Expr (Maybe a) -> Expr Bool
forall a. (Columnable a, Eq a) => Expr a -> Expr a -> Expr Bool
.==)
                            , (forall a. Columnable a => Text -> Expr a
Col @(Maybe a) Text
colName Expr (Maybe a) -> Expr (Maybe a) -> Expr Bool
forall a.
(Columnable a, Ord a, Eq a) =>
Expr a -> Expr a -> Expr Bool
.<=)
                            , (forall a. Columnable a => Text -> Expr a
Col @(Maybe a) Text
colName Expr (Maybe a) -> Expr (Maybe a) -> Expr Bool
forall a.
(Columnable a, Ord a, Eq a) =>
Expr a -> Expr a -> Expr Bool
.>=)
                            ]
                            ( Maybe a -> Expr (Maybe a)
forall a. Columnable a => a -> Expr a
Lit Maybe a
forall a. Maybe a
Nothing
                                Expr (Maybe a) -> [Expr (Maybe a)] -> [Expr (Maybe a)]
forall a. a -> [a] -> [a]
: (Int -> Expr (Maybe a)) -> [Int] -> [Expr (Maybe a)]
forall a b. (a -> b) -> [a] -> [b]
map
                                    (Maybe a -> Expr (Maybe a)
forall a. Columnable a => a -> Expr a
Lit (Maybe a -> Expr (Maybe a))
-> (Int -> Maybe a) -> Int -> Expr (Maybe a)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. a -> Maybe a
forall a. a -> Maybe a
Just (a -> Maybe a) -> (Int -> a) -> Int -> Maybe a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Double -> a
forall a b. (Real a, Fractional b) => a -> b
realToFrac (Double -> a) -> (Int -> Double) -> Int -> a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Int -> Vector Double -> Double
forall a. (Unbox a, Num a, Real a) => Int -> Vector a -> Double
`percentile'` Vector Double
doubleCol))
                                    (TreeConfig -> [Int]
percentiles TreeConfig
cfg)
                            )
                SBool (FloatingTypes a)
SFalse -> case forall a. SBoolI (IntegralTypes a) => SBool (IntegralTypes a)
sIntegral @a of
                    SBool (IntegralTypes a)
STrue ->
                        let doubleCol :: Vector Double
doubleCol =
                                Vector Double -> Vector Double
forall (v :: * -> *) a (w :: * -> *).
(Vector v a, Vector w a) =>
v a -> w a
VU.convert
                                    ((Maybe Double -> Double) -> Vector (Maybe Double) -> Vector Double
forall a b. (a -> b) -> Vector a -> Vector b
V.map Maybe Double -> Double
forall a. HasCallStack => Maybe a -> a
fromJust ((Maybe Double -> Bool)
-> Vector (Maybe Double) -> Vector (Maybe Double)
forall a. (a -> Bool) -> Vector a -> Vector a
V.filter Maybe Double -> Bool
forall a. Maybe a -> Bool
isJust ((Maybe a -> Maybe Double)
-> Vector (Maybe a) -> Vector (Maybe Double)
forall a b. (a -> b) -> Vector a -> Vector b
V.map ((a -> Double) -> Maybe a -> Maybe Double
forall a b. (a -> b) -> Maybe a -> Maybe b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall a b. (Integral a, Num b) => a -> b
fromIntegral @a @Double)) Vector (Maybe a)
col)))
                         in ((Expr (Maybe a) -> Expr Bool) -> Expr (Maybe a) -> Expr Bool)
-> [Expr (Maybe a) -> Expr Bool] -> [Expr (Maybe a)] -> [Expr Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
                                (Expr (Maybe a) -> Expr Bool) -> Expr (Maybe a) -> Expr Bool
forall a b. (a -> b) -> a -> b
($)
                                [ (forall a. Columnable a => Text -> Expr a
Col @(Maybe a) Text
colName Expr (Maybe a) -> Expr (Maybe a) -> Expr Bool
forall a. (Columnable a, Eq a) => Expr a -> Expr a -> Expr Bool
.==)
                                , (forall a. Columnable a => Text -> Expr a
Col @(Maybe a) Text
colName Expr (Maybe a) -> Expr (Maybe a) -> Expr Bool
forall a.
(Columnable a, Ord a, Eq a) =>
Expr a -> Expr a -> Expr Bool
.<=)
                                , (forall a. Columnable a => Text -> Expr a
Col @(Maybe a) Text
colName Expr (Maybe a) -> Expr (Maybe a) -> Expr Bool
forall a.
(Columnable a, Ord a, Eq a) =>
Expr a -> Expr a -> Expr Bool
.>=)
                                ]
                                ( Maybe a -> Expr (Maybe a)
forall a. Columnable a => a -> Expr a
Lit Maybe a
forall a. Maybe a
Nothing
                                    Expr (Maybe a) -> [Expr (Maybe a)] -> [Expr (Maybe a)]
forall a. a -> [a] -> [a]
: (Int -> Expr (Maybe a)) -> [Int] -> [Expr (Maybe a)]
forall a b. (a -> b) -> [a] -> [b]
map
                                        (Maybe a -> Expr (Maybe a)
forall a. Columnable a => a -> Expr a
Lit (Maybe a -> Expr (Maybe a))
-> (Int -> Maybe a) -> Int -> Expr (Maybe a)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. a -> Maybe a
forall a. a -> Maybe a
Just (a -> Maybe a) -> (Int -> a) -> Int -> Maybe a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Double -> a
forall b. Integral b => Double -> b
forall a b. (RealFrac a, Integral b) => a -> b
round (Double -> a) -> (Int -> Double) -> Int -> a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Int -> Vector Double -> Double
forall a. (Unbox a, Num a, Real a) => Int -> Vector a -> Double
`percentile'` Vector Double
doubleCol))
                                        (TreeConfig -> [Int]
percentiles TreeConfig
cfg)
                                )
                    SBool (IntegralTypes a)
SFalse ->
                        (Int -> Expr Bool) -> [Int] -> [Expr Bool]
forall a b. (a -> b) -> [a] -> [b]
map
                            ((forall a. Columnable a => Text -> Expr a
Col @(Maybe a) Text
colName Expr (Maybe a) -> Expr (Maybe a) -> Expr Bool
forall a. (Columnable a, Eq a) => Expr a -> Expr a -> Expr Bool
.==) (Expr (Maybe a) -> Expr Bool)
-> (Int -> Expr (Maybe a)) -> Int -> Expr Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Maybe a -> Expr (Maybe a)
forall a. Columnable a => a -> Expr a
Lit (Maybe a -> Expr (Maybe a))
-> (Int -> Maybe a) -> Int -> Expr (Maybe a)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Int -> Vector (Maybe a) -> Maybe a
forall a. (Ord a, Eq a) => Int -> Vector a -> a
`percentileOrd'` Vector (Maybe a)
col))
                            [Int
1, Int
25, Int
75, Int
99]
            (UnboxedColumn (Vector a
_ :: VU.Vector a)) -> []

        columnConds :: [Expr Bool]
columnConds =
            ((Text, Text) -> [Expr Bool]) -> [(Text, Text)] -> [Expr Bool]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap
                (Text, Text) -> [Expr Bool]
colConds
                [ (Text
l, Text
r)
                | Text
l <- DataFrame -> [Text]
columnNames DataFrame
df
                , Text
r <- DataFrame -> [Text]
columnNames DataFrame
df
                , Bool -> Bool
not
                    ( ((Text, Text) -> Bool) -> [(Text, Text)] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any
                        (\(Text
l', Text
r') -> [Text] -> [Text]
forall a. Ord a => [a] -> [a]
sort [Text
l', Text
r'] [Text] -> [Text] -> Bool
forall a. Eq a => a -> a -> Bool
== [Text] -> [Text]
forall a. Ord a => [a] -> [a]
sort [Text
l, Text
r])
                        (SynthConfig -> [(Text, Text)]
disallowedCombinations (TreeConfig -> SynthConfig
synthConfig TreeConfig
cfg))
                    )
                ]
          where
            colConds :: (Text, Text) -> [Expr Bool]
colConds (!Text
l, !Text
r) = case (Text -> DataFrame -> Column
unsafeGetColumn Text
l DataFrame
df, Text -> DataFrame -> Column
unsafeGetColumn Text
r DataFrame
df) of
                (BoxedColumn (Vector a
col1 :: V.Vector a), BoxedColumn (Vector a
_ :: V.Vector b)) ->
                    case TypeRep a -> TypeRep a -> Maybe (a :~: a)
forall a b. TypeRep a -> TypeRep b -> Maybe (a :~: b)
forall {k} (f :: k -> *) (a :: k) (b :: k).
TestEquality f =>
f a -> f b -> Maybe (a :~: b)
testEquality (forall a. Typeable a => TypeRep a
forall {k} (a :: k). Typeable a => TypeRep a
typeRep @a) (forall a. Typeable a => TypeRep a
forall {k} (a :: k). Typeable a => TypeRep a
typeRep @b) of
                        Maybe (a :~: a)
Nothing -> []
                        Just a :~: a
Refl -> [forall a. Columnable a => Text -> Expr a
Col @a Text
l Expr a -> Expr a -> Expr Bool
forall a. (Columnable a, Eq a) => Expr a -> Expr a -> Expr Bool
.== forall a. Columnable a => Text -> Expr a
Col @a Text
r]
                (UnboxedColumn (Vector a
_ :: VU.Vector a), UnboxedColumn (Vector a
_ :: VU.Vector b)) -> []
                ( OptionalColumn (Vector (Maybe a)
_ :: V.Vector (Maybe a))
                    , OptionalColumn (Vector (Maybe a)
_ :: V.Vector (Maybe b))
                    ) -> case TypeRep a -> TypeRep a -> Maybe (a :~: a)
forall a b. TypeRep a -> TypeRep b -> Maybe (a :~: b)
forall {k} (f :: k -> *) (a :: k) (b :: k).
TestEquality f =>
f a -> f b -> Maybe (a :~: b)
testEquality (forall a. Typeable a => TypeRep a
forall {k} (a :: k). Typeable a => TypeRep a
typeRep @a) (forall a. Typeable a => TypeRep a
forall {k} (a :: k). Typeable a => TypeRep a
typeRep @b) of
                        Maybe (a :~: a)
Nothing -> []
                        Just a :~: a
Refl -> case TypeRep a -> TypeRep Text -> Maybe (a :~: Text)
forall a b. TypeRep a -> TypeRep b -> Maybe (a :~: b)
forall {k} (f :: k -> *) (a :: k) (b :: k).
TestEquality f =>
f a -> f b -> Maybe (a :~: b)
testEquality (forall a. Typeable a => TypeRep a
forall {k} (a :: k). Typeable a => TypeRep a
typeRep @a) (forall a. Typeable a => TypeRep a
forall {k} (a :: k). Typeable a => TypeRep a
typeRep @T.Text) of
                            Maybe (a :~: Text)
Nothing -> [forall a. Columnable a => Text -> Expr a
Col @(Maybe a) Text
l Expr (Maybe a) -> Expr (Maybe a) -> Expr Bool
forall a.
(Columnable a, Ord a, Eq a) =>
Expr a -> Expr a -> Expr Bool
.<= Text -> Expr (Maybe a)
forall a. Columnable a => Text -> Expr a
Col Text
r, forall a. Columnable a => Text -> Expr a
Col @(Maybe a) Text
l Expr (Maybe a) -> Expr (Maybe a) -> Expr Bool
forall a. (Columnable a, Eq a) => Expr a -> Expr a -> Expr Bool
.== Text -> Expr (Maybe a)
forall a. Columnable a => Text -> Expr a
Col Text
r]
                            Just a :~: Text
Refl -> [forall a. Columnable a => Text -> Expr a
Col @(Maybe a) Text
l Expr (Maybe a) -> Expr (Maybe a) -> Expr Bool
forall a. (Columnable a, Eq a) => Expr a -> Expr a -> Expr Bool
.== Text -> Expr (Maybe a)
forall a. Columnable a => Text -> Expr a
Col Text
r]
                (Column, Column)
_ -> []
     in
        (Text -> [Expr Bool]) -> [Text] -> [Expr Bool]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Text -> [Expr Bool]
genConds (DataFrame -> [Text]
columnNames DataFrame
df) [Expr Bool] -> [Expr Bool] -> [Expr Bool]
forall a. [a] -> [a] -> [a]
++ [Expr Bool]
columnConds

{- | Split a frame in two using a boolean expression.

The first component is @filterWhere cond df@; the second is
@filterWhere (F.not cond) df@, i.e. the frame filtered by the negated
condition.
-}
partitionDataFrame :: Expr Bool -> DataFrame -> (DataFrame, DataFrame)
partitionDataFrame cond df = (keep cond, keep (F.not cond))
  where
    -- Both halves are filtered out of the same input frame.
    keep predicate = filterWhere predicate df

{- | Gini impurity of the @target@ column, interpreted at type @a@.

Class probabilities are add-one smoothed: a class seen @c@ times gets
probability @(c + 1) / (n + k)@, where @n@ is the row count of the frame
and @k@ the number of distinct classes returned by 'getCounts'. The result
is @1 - sum p^2@. A frame with zero rows yields @0@.
-}
calculateGini :: forall a. (Columnable a) => T.Text -> DataFrame -> Double
calculateGini target df
    | rowCount == 0 = 0
    | otherwise = 1 - sum [q * q | q <- smoothed]
  where
    rowCount :: Double
    rowCount = fromIntegral (nRows df)

    tally = getCounts @a target df

    classCount :: Double
    classCount = fromIntegral (M.size tally)

    -- Add-one smoothing keeps every observed class away from probability 1.
    smoothed :: [Double]
    smoothed =
        [ (fromIntegral c + 1) / (rowCount + classCount)
        | c <- M.elems tally
        ]

{- | Most frequent value of the @target@ column, interpreted at type @a@.

Frequencies come from 'getCounts'. Calls 'error' when the count map is
empty. Ties are resolved by 'maximumBy' over the ascending key order of
the map.
-}
majorityValue :: forall a. (Columnable a) => T.Text -> DataFrame -> a
majorityValue target df
    | M.null tally = error "Empty DataFrame in leaf"
    | otherwise = fst (maximumBy byCount (M.toList tally))
  where
    tally = getCounts @a target df
    -- Compare (value, count) pairs on the count only.
    byCount = compare `on` snd

{- | Count how many times each distinct value occurs in the @target@ column.

The column is read by interpreting @Col target@ at type @a@ and converting
the result to a boxed vector. A failure in either step is rethrown with
'throw'.
-}
getCounts :: forall a. (Columnable a) => T.Text -> DataFrame -> M.Map a Int
getCounts target df =
    case interpret @a df (Col target) of
        Left err -> throw err
        Right (TColumn column) ->
            case toVector @a column of
                Left err -> throw err
                -- Strict left fold straight over the vector.
                Right values -> V.foldl' bump M.empty values
  where
    bump :: M.Map a Int -> a -> M.Map a Int
    bump seen x = M.insertWith (+) x 1 seen

{- | Value at the @p@-th percentile of a 'Double' expression over a frame.

The expression is interpreted, its values sorted ascending, and the element
at index @(p * n) `div` 100@ (clamped into @[0, n - 1]@) is returned.
Returns @0@ when interpretation or vector conversion fails, or when there
are no values.
-}
percentile :: Int -> Expr Double -> DataFrame -> Double
percentile p expr df =
    case interpret @Double df expr of
        Right (TColumn column)
            | Right values <- toVector @Double column ->
                pick (V.fromList (sort (V.toList values)))
        -- Any failure along the way falls back to 0.
        _ -> 0
  where
    pick :: V.Vector Double -> Double
    pick ranked
        | V.null ranked = 0
        | otherwise = ranked V.! clamp ((p * total) `div` 100)
      where
        total = V.length ranked
        -- Keep the index inside the vector even for p <= 0 or p >= 100.
        clamp = min (total - 1) . max 0

{- | Build a decision tree for @target@ and return it as an expression.

Runs three stages: a tree from 'buildGreedyTree', a pass of 'taoOptimize'
over all row indices @0 .. nRows df - 1@, then 'treeToExpr' followed by
'pruneExpr'.
-}
buildTree ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    Int ->
    T.Text ->
    [Expr Bool] ->
    DataFrame ->
    Expr a
buildTree cfg depth target conds df = pruneExpr (treeToExpr refined)
  where
    -- Every row of the frame takes part in the optimisation pass.
    rowIds = V.enumFromN 0 (nRows df)
    greedy = buildGreedyTree @a cfg depth target conds df
    refined = taoOptimize @a cfg target conds df rowIds greedy

{- | Alias for 'findBestGreedySplit' at target type @a@.

All arguments are forwarded unchanged.
-}
findBestSplit ::
    forall a.
    (Columnable a) =>
    TreeConfig -> T.Text -> [Expr Bool] -> DataFrame -> Maybe (Expr Bool)
findBestSplit cfg target conds df =
    findBestGreedySplit @a cfg target conds df

-- | Alias for 'pruneExpr'; the expression is forwarded unchanged.
pruneTree :: forall a. (Columnable a, Eq a) => Expr a -> Expr a
pruneTree expr = pruneExpr expr