{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module DataFrame.DecisionTree where
import qualified DataFrame.Functions as F
import DataFrame.Internal.Column
import DataFrame.Internal.DataFrame (DataFrame (..), unsafeGetColumn)
import DataFrame.Internal.Expression (Expr (..), eSize, getColumns)
import DataFrame.Internal.Interpreter (interpret)
import DataFrame.Internal.Statistics (percentile', percentileOrd')
import DataFrame.Internal.Types
import DataFrame.Operations.Core (columnNames, nRows)
import DataFrame.Operations.Subset (exclude, filterWhere)
import Control.Exception (throw)
import Control.Monad (guard)
import Data.Containers.ListUtils (nubOrd)
import Data.Function (on)
import Data.List (foldl', maximumBy, minimumBy, sort, sortBy)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Text as T
import Data.Type.Equality
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as VU
import Type.Reflection (typeRep)
import DataFrame.Operators
-- | Hyper-parameters controlling how 'fitDecisionTree' grows and refines a tree.
data TreeConfig = TreeConfig
    { maxTreeDepth :: Int
    -- ^ Depth limit handed to the greedy tree builder.
    , minSamplesSplit :: Int
    -- ^ Minimum sample count for splitting a node. Not read in this part of
    --   the module; presumably consumed by the greedy builder (TODO confirm).
    , minLeafSize :: Int
    -- ^ A split is only accepted when each side keeps at least this many rows.
    , percentiles :: [Int]
    -- ^ Percentile levels, presumably used to derive numeric thresholds for
    --   candidate conditions (verify against 'numericConditions').
    , expressionPairs :: Int
    -- ^ Number of best-scoring candidate conditions kept before boolean
    --   expansion during the TAO split search.
    , synthConfig :: SynthConfig
    -- ^ Settings for synthesizing candidate conditions.
    , taoIterations :: Int
    -- ^ Upper bound on the number of TAO refinement passes.
    , taoConvergenceTol :: Double
    -- ^ Refinement stops once a pass improves the loss by less than this.
    }
    deriving (Eq, Show)
-- | Settings for synthesizing candidate split conditions.
data SynthConfig = SynthConfig
    { maxExprDepth :: Int
    -- ^ Expression depth limit for synthesized conditions (not read in this
    --   part of the module — TODO confirm where it is consumed).
    , boolExpansion :: Int
    -- ^ Expansion budget passed to 'boolExprs' when combining candidate
    --   conditions in the TAO split search.
    , disallowedCombinations :: [(T.Text, T.Text)]
    -- ^ Pairs of column names; presumably combinations that must not be
    --   synthesized together (verify against the generator).
    , complexityPenalty :: Double
    -- ^ Weight multiplied by an expression's size; the floored product is
    --   added to a candidate's error count.
    , enableStringOps :: Bool
    -- ^ Toggle for string-based operations (consumer not visible here).
    , enableCrossCols :: Bool
    -- ^ Toggle for conditions relating multiple columns (consumer not
    --   visible here).
    , enableArithOps :: Bool
    -- ^ Toggle for arithmetic operations (consumer not visible here).
    }
    deriving (Eq, Show)
-- | Default synthesis settings: expression depth 2, a boolean-expansion budget
-- of 2, no disallowed column pairs, a complexity-penalty weight of 0.05, and
-- every optional operation family switched on.
defaultSynthConfig :: SynthConfig
defaultSynthConfig =
    SynthConfig
        { maxExprDepth = 2
        , boolExpansion = 2
        , disallowedCombinations = []
        , complexityPenalty = 0.05
        , enableStringOps = True
        , enableCrossCols = True
        , enableArithOps = True
        }
-- | Default tree settings: depth limit 4, minimum split size 5, minimum leaf
-- size 1, percentiles 0, 10 .. 100, the 10 best candidates kept per split
-- search, 'defaultSynthConfig', and at most 10 TAO passes with a convergence
-- tolerance of 1e-6.
defaultTreeConfig :: TreeConfig
defaultTreeConfig =
    TreeConfig
        { maxTreeDepth = 4
        , minSamplesSplit = 5
        , minLeafSize = 1
        , percentiles = map (* 10) [0 .. 10]
        , expressionPairs = 10
        , synthConfig = defaultSynthConfig
        , taoIterations = 10
        , taoConvergenceTol = 1e-6
        }
-- | A binary decision tree. A 'Branch' sends a row to its first (left) subtree
-- when the condition is 'True' for that row and to its second (right) subtree
-- otherwise; a 'Leaf' carries the prediction.
data Tree a
    = Leaf !a
    | Branch !(Expr Bool) !(Tree a) !(Tree a)
    deriving (Eq, Show)
-- | Number of 'Branch' levels on the longest root-to-leaf path; a single
-- 'Leaf' has depth 0.
treeDepth :: Tree a -> Int
treeDepth tree = case tree of
    Leaf _ -> 0
    Branch _ l r -> succ (treeDepth l `max` treeDepth r)
-- | Turns a tree into an expression: leaves become literals and branches become
-- nested 'F.ifThenElse' expressions over their conditions.
treeToExpr :: (Columnable a) => Tree a -> Expr a
treeToExpr tree = case tree of
    Leaf value -> Lit value
    Branch cond onTrue onFalse ->
        F.ifThenElse cond (treeToExpr onTrue) (treeToExpr onFalse)
-- | Fits a majority-vote decision tree predicting the column referenced by the
-- target expression and returns it as a simplified if-then-else expression.
--
-- Candidate conditions are generated from every column except the target. A
-- greedy tree is grown to 'maxTreeDepth', refined with 'taoOptimize' over all
-- rows, converted with 'treeToExpr', and simplified with 'pruneExpr'.
--
-- Only a bare column reference ('Col') is accepted as the target; any other
-- expression raises an 'error'.
fitDecisionTree ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    Expr a ->
    DataFrame ->
    Expr a
fitDecisionTree cfg (Col target) df = pruneExpr (treeToExpr optimizedTree)
  where
    -- The frame without the target column, shared by both condition generators.
    features :: DataFrame
    features = exclude [target] df

    conds :: [Expr Bool]
    conds =
        nubOrd (numericConditions cfg features ++ generateConditionsOld cfg features)

    initialTree :: Tree a
    initialTree = buildGreedyTree @a cfg (maxTreeDepth cfg) target conds df

    allRows :: V.Vector Int
    allRows = V.enumFromN 0 (nRows df)

    optimizedTree :: Tree a
    optimizedTree = taoOptimize @a cfg target conds df allRows initialTree
fitDecisionTree _ expr _ =
    error $ "Cannot create tree for compound expression: " ++ show expr
-- | Repeatedly refines a tree with 'taoIteration' until 'taoIterations' passes
-- have run or a pass improves the misclassification loss on @rootIndices@ by
-- less than 'taoConvergenceTol'.
--
-- A pass that makes the loss strictly worse is discarded: the tree from before
-- that pass is returned. So the result is never worse than any tree accepted
-- along the way. (The previous version returned the worse tree, because a
-- negative improvement is also below the tolerance.)
taoOptimize ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    T.Text ->
    [Expr Bool] ->
    DataFrame ->
    V.Vector Int ->
    Tree a ->
    Tree a
taoOptimize cfg target conds df rootIndices initialTree =
    go 0 initialTree (lossOf initialTree)
  where
    lossOf :: Tree a -> Double
    lossOf = computeTreeLoss @a target df rootIndices

    go :: Int -> Tree a -> Double -> Tree a
    go iter tree prevLoss
        | iter >= taoIterations cfg = pruneDead tree
        -- The pass made things worse: keep the tree we already had.
        | newLoss > prevLoss = pruneDead tree
        -- Improvement too small to be worth another pass.
        | prevLoss - newLoss < taoConvergenceTol cfg = pruneDead tree'
        | otherwise = go (iter + 1) tree' newLoss
      where
        tree' :: Tree a
        tree' = taoIteration @a cfg target conds df rootIndices tree

        newLoss :: Double
        newLoss = lossOf tree'
-- | One TAO pass: re-optimizes the nodes of each depth level with
-- 'optimizeDepthLevel'. It starts at the deepest level of the incoming tree
-- and finishes at the root (depth 0).
taoIteration ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    T.Text ->
    [Expr Bool] ->
    DataFrame ->
    V.Vector Int ->
    Tree a ->
    Tree a
taoIteration cfg target conds df rootIndices tree =
    foldl' refineLevel tree deepestFirst
  where
    -- Depth levels are taken from the tree as it was at the start of the pass.
    deepestFirst :: [Int]
    deepestFirst = reverse [0 .. treeDepth tree]

    refineLevel :: Tree a -> Int -> Tree a
    refineLevel = optimizeDepthLevel @a cfg target conds df rootIndices
-- | Re-optimizes every node located @targetDepth@ levels below the root. The
-- descent starts at the root (depth 0) with @rootIndices@.
optimizeDepthLevel ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    T.Text ->
    [Expr Bool] ->
    DataFrame ->
    V.Vector Int ->
    Tree a ->
    Int ->
    Tree a
optimizeDepthLevel cfg target conds df rootIndices tree targetDepth =
    optimizeAtDepth @a cfg target conds df rootIndices tree 0 targetDepth
-- | Walks down from a node that sits at @currentDepth@. At each branch, the
-- rows are routed using the branch's existing condition. When @targetDepth@
-- is reached, the node and the rows that reach it are handed to
-- 'optimizeNode'. Leaves encountered above the target depth are returned
-- unchanged.
optimizeAtDepth ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    T.Text ->
    [Expr Bool] ->
    DataFrame ->
    V.Vector Int ->
    Tree a ->
    Int ->
    Int ->
    Tree a
optimizeAtDepth cfg target conds df indices tree currentDepth targetDepth
    | currentDepth == targetDepth =
        optimizeNode @a cfg target conds df indices tree
    | Branch cond left right <- tree =
        let (indicesL, indicesR) = partitionIndices cond df indices
         in Branch cond (descend indicesL left) (descend indicesR right)
    | otherwise = tree
  where
    -- Continue one level deeper with the rows routed to that child.
    descend :: V.Vector Int -> Tree a -> Tree a
    descend subIndices subtree =
        optimizeAtDepth @a
            cfg
            target
            conds
            df
            subIndices
            subtree
            (currentDepth + 1)
            targetDepth
-- | Re-fits a single node on the rows (@indices@) that reach it.
--
-- * No rows: the node is returned untouched.
-- * 'Leaf': relabelled with the majority target value of those rows.
-- * 'Branch': its condition is replaced by 'findBestSplitTAO' and both
--   children are kept. If the new condition leaves fewer than 'minLeafSize'
--   rows on either side, the branch collapses to a majority-value leaf.
optimizeNode ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    T.Text ->
    [Expr Bool] ->
    DataFrame ->
    V.Vector Int ->
    Tree a ->
    Tree a
optimizeNode cfg target conds df indices tree
    | V.null indices = tree
    | otherwise = case tree of
        Leaf _ -> majorityLeaf
        Branch oldCond left right
            | tooSmall newIndicesL || tooSmall newIndicesR -> majorityLeaf
            | otherwise -> Branch newCond left right
          where
            newCond :: Expr Bool
            newCond =
                findBestSplitTAO @a cfg target conds df indices left right oldCond
            (newIndicesL, newIndicesR) = partitionIndices newCond df indices
  where
    majorityLeaf :: Tree a
    majorityLeaf = Leaf (majorityValueFromIndices @a target df indices)

    tooSmall :: V.Vector Int -> Bool
    tooSmall side = V.length side < minLeafSize cfg
-- | Chooses a new condition for a branch whose two subtrees are held fixed.
--
-- Candidates are the @conds@ that leave at least 'minLeafSize' of @indices@ on
-- each side. Scoring uses only the care points: rows for which exactly one
-- subtree predicts the label correctly (see 'identifyCarePoints'). A candidate's
-- score is the number of care points it routes the wrong way, plus
-- @floor (complexityPenalty * eSize cond)@. The 'expressionPairs' best-scoring
-- candidates are expanded with 'boolExprs'. The lowest-scoring expanded
-- condition wins, and ties go to the earliest one.
--
-- Falls back to @currentCond@ when there are no rows, no valid candidates,
-- no care points, or no expanded conditions.
--
-- NOTE(review): expanded conditions are not re-checked with 'isValidSplit';
-- 'optimizeNode' collapses the node if the winner turns out unbalanced.
findBestSplitTAO ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    T.Text ->
    [Expr Bool] ->
    DataFrame ->
    V.Vector Int ->
    Tree a ->
    Tree a ->
    Expr Bool ->
    Expr Bool
findBestSplitTAO cfg target conds df indices leftTree rightTree currentCond
    | V.null indices = currentCond
    | null validConds = currentCond
    | null carePoints = currentCond
    | null expandedConds = currentCond
    | otherwise = snd (minimumBy (compare `on` fst) (scored expandedConds))
  where
    synth :: SynthConfig
    synth = synthConfig cfg

    validConds :: [Expr Bool]
    validConds = filter isValidSplit conds

    isValidSplit :: Expr Bool -> Bool
    isValidSplit c =
        let (t, f) = partitionIndices c df indices
         in V.length t >= minLeafSize cfg && V.length f >= minLeafSize cfg

    carePoints :: [CarePoint]
    carePoints = identifyCarePoints @a target df indices leftTree rightTree

    -- Care-point errors plus a size-based complexity penalty.
    evalWithPenalty :: Expr Bool -> Int
    evalWithPenalty c =
        countCarePointErrors c df carePoints
            + floor (complexityPenalty synth * fromIntegral (eSize c))

    -- Pair each condition with its score so the expensive score (which
    -- interprets the condition over the whole frame) is computed once per
    -- condition. Previously it was recomputed on every comparison inside
    -- sortBy/minimumBy. Sorting on 'fst' keeps the same stable order and
    -- tie-breaking.
    scored :: [Expr Bool] -> [(Int, Expr Bool)]
    scored = map (\c -> (evalWithPenalty c, c))

    sortedConds :: [Expr Bool]
    sortedConds =
        take (expressionPairs cfg) $
            map snd (sortBy (compare `on` fst) (scored validConds))

    expandedConds :: [Expr Bool]
    expandedConds = boolExprs df sortedConds sortedConds 0 (boolExpansion synth)
-- | A row whose routing at a branch matters: exactly one of the branch's two
-- subtrees predicts this row's label correctly.
data CarePoint = CarePoint
    { cpIndex :: !Int
    -- ^ Row index into the 'DataFrame'.
    , cpCorrectDir :: !Direction
    -- ^ The child whose subtree predicts this row correctly.
    }
    deriving (Eq, Show)
-- | Which child of a 'Branch' a row should be sent to.
data Direction
    = GoLeft
    | GoRight
    deriving (Eq, Show)
-- | Finds the care points among @indices@: rows for which exactly one of the
-- two subtrees predicts the target correctly. The correct side is recorded as
-- the row's 'Direction'. Rows that both or neither subtree get right are
-- dropped. Output order follows @indices@. Returns @[]@ if the target column
-- cannot be read.
--
-- Each subtree is evaluated in batch. All rows are routed through it at once
-- with 'partitionIndices', so every branch condition is interpreted a single
-- time. The previous per-row 'predictWithTree' calls re-interpreted each
-- condition over the whole frame for every row, which is O(n^2). Routing
-- semantics are identical: 'True' goes left, and an unevaluable condition
-- sends rows left.
identifyCarePoints ::
    forall a.
    (Columnable a) =>
    T.Text ->
    DataFrame ->
    V.Vector Int ->
    Tree a ->
    Tree a ->
    [CarePoint]
identifyCarePoints target df indices leftTree rightTree =
    case interpret @a df (Col target) of
        Left _ -> []
        Right (TColumn col) ->
            case toVector @a col of
                Left _ -> []
                Right targetVals ->
                    V.toList (V.mapMaybe (checkPoint targetVals) indices)
  where
    -- Predictions of each subtree for every row in @indices@, computed once.
    leftPreds :: M.Map Int a
    leftPreds = routeAll leftTree

    rightPreds :: M.Map Int a
    rightPreds = routeAll rightTree

    -- Send all of @indices@ through a subtree together and record the leaf
    -- value each row lands on.
    routeAll :: Tree a -> M.Map Int a
    routeAll subtree = M.fromList (go subtree indices)
      where
        go :: Tree a -> V.Vector Int -> [(Int, a)]
        go node rows
            | V.null rows = []
            | otherwise = case node of
                Leaf v -> [(i, v) | i <- V.toList rows]
                Branch cond l r ->
                    let (rowsL, rowsR) = partitionIndices cond df rows
                     in go l rowsL ++ go r rowsR

    -- Every idx comes from @indices@, so both maps are guaranteed to contain it.
    checkPoint :: V.Vector a -> Int -> Maybe CarePoint
    checkPoint targetVals idx =
        case (leftPreds M.! idx == trueLabel, rightPreds M.! idx == trueLabel) of
            (True, False) -> Just (CarePoint idx GoLeft)
            (False, True) -> Just (CarePoint idx GoRight)
            _ -> Nothing
      where
        trueLabel = targetVals V.! idx
-- | Predicts row @idx@ by walking the tree from the root. At each branch the
-- condition is interpreted as a 'Bool' column over the frame and the row's
-- entry decides the path: 'True' goes left, 'False' goes right. If the
-- condition cannot be evaluated, the row goes left. The target name is not
-- used.
--
-- NOTE(review): each branch interprets its condition over the entire frame for
-- a single row; prefer batch routing with 'partitionIndices' inside loops.
predictWithTree ::
    forall a.
    (Columnable a) =>
    T.Text ->
    DataFrame ->
    Int ->
    Tree a ->
    a
predictWithTree _ _ _ (Leaf v) = v
predictWithTree target df idx (Branch cond left right)
    | goesLeft = descend left
    | otherwise = descend right
  where
    descend :: Tree a -> a
    descend = predictWithTree @a target df idx

    -- Evaluation failures default to the left child.
    goesLeft :: Bool
    goesLeft = case interpret @Bool df cond of
        Left _ -> True
        Right (TColumn col) -> either (const True) (V.! idx) (toVector @Bool col)
-- | Counts the care points that @cond@ sends to the wrong child. A 'True'
-- value means left, and each care point records its correct side. If the
-- condition cannot be evaluated as a 'Bool' column, every care point counts
-- as an error.
countCarePointErrors :: Expr Bool -> DataFrame -> [CarePoint] -> Int
countCarePointErrors cond df carePoints =
    maybe (length carePoints) countWrong routing
  where
    routing :: Maybe (V.Vector Bool)
    routing = case interpret @Bool df cond of
        Left _ -> Nothing
        Right (TColumn col) -> either (const Nothing) Just (toVector @Bool col)

    countWrong :: V.Vector Bool -> Int
    countWrong mask = length [cp | cp <- carePoints, misrouted mask cp]

    misrouted :: V.Vector Bool -> CarePoint -> Bool
    misrouted mask (CarePoint i dir) = (mask V.! i) /= (dir == GoLeft)
-- | Splits @indices@ into the rows where @cond@ is 'True' (first component)
-- and the rows where it is 'False' (second component), keeping their relative
-- order. If the condition cannot be evaluated as a 'Bool' column, every index
-- goes to the first component.
partitionIndices ::
    Expr Bool -> DataFrame -> V.Vector Int -> (V.Vector Int, V.Vector Int)
partitionIndices cond df indices =
    case interpret @Bool df cond of
        Right (TColumn col)
            | Right mask <- toVector @Bool col -> V.partition (mask V.!) indices
        _ -> (indices, V.empty)
-- | The most frequent target value among the rows in @indices@. When several
-- values tie for the highest count, the largest one (by 'Ord') is returned.
--
-- Throws the interpreter's exception if the target column cannot be read, and
-- calls 'error' when @indices@ is empty.
majorityValueFromIndices ::
    forall a.
    (Columnable a) =>
    T.Text ->
    DataFrame ->
    V.Vector Int ->
    a
majorityValueFromIndices target df indices =
    case interpret @a df (Col target) of
        Left err -> throw err
        Right (TColumn col) -> case toVector @a col of
            Left err -> throw err
            Right vals -> pickMostFrequent (tally vals)
  where
    tally :: V.Vector a -> M.Map a Integer
    tally vals =
        V.foldl'
            (\acc i -> M.insertWith (+) (vals V.! i) (1 :: Integer) acc)
            M.empty
            indices

    -- 'maximumBy' keeps the last of equal maxima; 'M.toList' is ascending.
    pickMostFrequent :: M.Map a Integer -> a
    pickMostFrequent counts
        | M.null counts = error "Empty indices in majorityValueFromIndices"
        | otherwise = fst (maximumBy (compare `on` snd) (M.toList counts))
-- | Misclassification rate of @tree@ on the rows in @indices@: the fraction of
-- those rows whose target value differs from the tree's prediction. Returns 0
-- for an empty index set and 1.0 when the target column cannot be read.
--
-- Rows are routed through the tree in batch with 'partitionIndices', so each
-- branch condition is interpreted once per call. The previous per-row
-- 'predictWithTree' loop re-interpreted every condition over the whole frame
-- for each row, which is O(n^2). Routing semantics are the same: 'True' goes
-- left, and an unevaluable condition sends rows left.
computeTreeLoss ::
    forall a.
    (Columnable a) =>
    T.Text ->
    DataFrame ->
    V.Vector Int ->
    Tree a ->
    Double
computeTreeLoss target df indices tree
    | V.null indices = 0
    | otherwise =
        case interpret @a df (Col target) of
            Left _ -> 1.0
            Right (TColumn col) ->
                case toVector @a col of
                    Left _ -> 1.0
                    Right targetVals ->
                        fromIntegral (countErrors targetVals tree indices)
                            / fromIntegral (V.length indices)
  where
    -- Number of rows in @rows@ that the subtree mislabels.
    countErrors :: V.Vector a -> Tree a -> V.Vector Int -> Int
    countErrors targetVals node rows
        -- Nothing reaches this subtree: skip interpreting its conditions.
        | V.null rows = 0
        | otherwise = case node of
            Leaf v -> V.length (V.filter (\i -> targetVals V.! i /= v) rows)
            Branch cond l r ->
                let (rowsL, rowsR) = partitionIndices cond df rows
                 in countErrors targetVals l rowsL + countErrors targetVals r rowsR
-- | Rebuilds the tree bottom-up, keeping every condition and leaf.
--
-- NOTE(review): despite the name, no nodes are removed. Each 'Branch' is
-- reconstructed from its recursively processed children, so the result is
-- structurally equal to the input. Redundant branches are simplified later by
-- 'pruneExpr' on the expression form.
pruneDead :: Tree a -> Tree a
pruneDead tree = case tree of
    Leaf v -> Leaf v
    Branch cond l r -> Branch cond (pruneDead l) (pruneDead r)
-- | Simplifies nested conditionals without changing what the expression
-- evaluates to. It recurses through 'If', 'Unary' and 'Binary' nodes and
-- leaves all other expressions as they are.
--
-- For @If c t f@ (after pruning @t@ and @f@):
--
-- * identical branches collapse to that branch;
-- * an inner @If c tt _@ in the true branch is replaced by @tt@, because @c@
--   already holds there;
-- * an inner @If c _ ff@ in the false branch is replaced by @ff@.
--
-- These rules are reapplied until none fires. For example,
-- @If c (If c x y) x@ now reduces to @x@; the previous single-step version
-- left it as @If c x x@.
pruneExpr :: forall a. (Columnable a, Eq a) => Expr a -> Expr a
pruneExpr (If cond trueBranch falseBranch) =
    simplifyIf cond (pruneExpr trueBranch) (pruneExpr falseBranch)
  where
    -- Each recursive call strictly shrinks one branch, so this terminates.
    simplifyIf :: Expr Bool -> Expr a -> Expr a -> Expr a
    simplifyIf c t f
        | t == f = t
        | If ct tt _ <- t, ct == c = simplifyIf c tt f
        | If cf _ ff <- f, cf == c = simplifyIf c t ff
        | otherwise = If c t f
pruneExpr (Unary op e) = Unary op (pruneExpr e)
pruneExpr (Binary op l r) = Binary op (pruneExpr l) (pruneExpr r)
pruneExpr e = e
-- | Grow a decision tree greedily, top-down.
--
-- A leaf holding the majority target value is produced when:
--
--   * the depth budget is exhausted (@depth <= 0@);
--   * the frame has at most 'minSamplesSplit' rows;
--   * 'findBestGreedySplit' finds no split;
--   * the chosen split leaves either side with fewer than 'minLeafSize' rows.
--
-- Otherwise the frame is partitioned on the chosen condition and both halves
-- are grown with one less level of depth. The first child covers the rows
-- where the condition holds.
--
-- NOTE(review): the stop test is @nRows df <= minSamplesSplit@, so a frame
-- with exactly 'minSamplesSplit' rows is never split — confirm this is the
-- intended meaning of the setting.
buildGreedyTree ::
    forall a.
    (Columnable a) =>
    TreeConfig ->
    Int ->
    T.Text ->
    [Expr Bool] ->
    DataFrame ->
    Tree a
buildGreedyTree cfg depth target conds df
    | outOfBudget = fallback
    | otherwise =
        maybe fallback splitOn (findBestGreedySplit @a cfg target conds df)
  where
    outOfBudget = depth <= 0 || nRows df <= minSamplesSplit cfg

    fallback = Leaf (majorityValue @a target df)

    tooSmall part = nRows part < minLeafSize cfg

    grow = buildGreedyTree @a cfg (depth - 1) target conds

    splitOn cond
        | tooSmall matching || tooSmall rest = fallback
        | otherwise = Branch cond (grow matching) (grow rest)
      where
        (matching, rest) = partitionDataFrame cond df
-- | Pick the split condition with the best penalised Gini gain, if any.
--
-- Every candidate is scored as the pair
-- @((impurity decrease) - complexityPenalty * eSize cond, negate (eSize cond))@,
-- so ties on the first component favour smaller expressions.
--
-- 1. Base conditions leaving fewer than 'minLeafSize' rows on either side are
--    discarded.
-- 2. Of the rest, those whose raw impurity decrease is positive are ranked by
--    score (best first) and the top 'expressionPairs' form the shortlist.
-- 3. The shortlist is expanded into AND/OR combinations with 'boolExprs' (up to
--    'boolExpansion' levels). Combinations are held to the same 'minLeafSize'
--    constraint as base conditions. Previously they were not, so a combination
--    with an undersized side could win, and the caller would then give up and
--    emit a leaf. The highest-scoring valid candidate is returned.
--
-- Each candidate partitions the frame exactly once.
--
-- Returns 'Nothing' when the shortlist is empty.
findBestGreedySplit ::
    forall a.
    (Columnable a) =>
    TreeConfig -> T.Text -> [Expr Bool] -> DataFrame -> Maybe (Expr Bool)
findBestGreedySplit cfg target conds df
    | null shortlist = Nothing
    | otherwise = case mapMaybe scored expanded of
        -- Unreachable in practice: the shortlist itself is part of the
        -- expansion and every shortlisted condition is valid.
        [] -> Nothing
        candidates -> Just (fst (maximumBy (compare `on` snd) candidates))
  where
    initialImpurity :: Double
    initialImpurity = calculateGini @a target df

    complexity :: Expr Bool -> Double
    complexity c =
        complexityPenalty (synthConfig cfg) * fromIntegral (eSize c)

    -- Partition once and score; Nothing when either side is under minLeafSize.
    evaluate :: Expr Bool -> Maybe (Double, Int)
    evaluate cond
        | nRows t < minLeafSize cfg || nRows f < minLeafSize cfg = Nothing
        | otherwise =
            Just
                ( (initialImpurity - newImpurity) - complexity cond
                , negate (eSize cond)
                )
      where
        (t, f) = partitionDataFrame cond df
        n = fromIntegral @Int @Double (nRows df)
        weightT = fromIntegral @Int @Double (nRows t) / n
        weightF = fromIntegral @Int @Double (nRows f) / n
        newImpurity =
            weightT * calculateGini @a target t
                + weightF * calculateGini @a target f

    scored :: Expr Bool -> Maybe (Expr Bool, (Double, Int))
    scored cond = fmap (\s -> (cond, s)) (evaluate cond)

    -- Keep only conditions whose impurity decrease is positive
    -- (score > -complexity), best first, capped at expressionPairs.
    shortlist :: [Expr Bool]
    shortlist =
        map fst
            . take (expressionPairs cfg)
            . filter (\(c, (s, _)) -> s > negate (complexity c))
            . sortBy (flip compare `on` snd)
            $ mapMaybe scored conds

    expanded :: [Expr Bool]
    expanded =
        boolExprs df shortlist shortlist 0 (boolExpansion (synthConfig cfg))
-- | Numeric threshold conditions for a frame; an alias for
-- 'generateNumericConds'.
numericConditions :: TreeConfig -> DataFrame -> [Expr Bool]
numericConditions cfg df = generateNumericConds cfg df
-- | Threshold conditions over every synthesised numeric expression.
--
-- For each expression from 'numericExprsWithTerms', one threshold is taken per
-- entry of 'percentiles' (that percentile of the expression over the frame).
-- Four comparisons are emitted per threshold, in the order @<=@, @>=@, @<@, @>@.
generateNumericConds :: TreeConfig -> DataFrame -> [Expr Bool]
generateNumericConds cfg df =
    [ cmp expr (F.lit threshold)
    | expr <- numericExprsWithTerms (synthConfig cfg) df
    , threshold <- [percentile p expr df | p <- percentiles cfg]
    , cmp <- [(.<=), (.>=), (.<), (.>)]
    ]
-- | All candidate numeric expressions up to the configured depth.
--
-- 'numericExprs' started at depth 0 already returns every level it builds:
-- base columns first, then each combined level below 'maxExprDepth'. A single
-- call therefore covers the whole search space.
--
-- Calling it once per depth in @[0 .. maxExprDepth]@ would yield the same
-- expressions, but with the shallow levels repeated many times. That
-- multiplies the downstream threshold conditions and crowds the
-- 'expressionPairs' shortlist with duplicates. Each shallower call returns a
-- prefix of the full one, so only the duplicates are dropped; the order of
-- first occurrences is unchanged.
--
-- A negative 'maxExprDepth' yields no expressions.
numericExprsWithTerms :: SynthConfig -> DataFrame -> [Expr Double]
numericExprsWithTerms cfg df
    | maxDepth < 0 = []
    | otherwise = numericExprs cfg df [] 0 maxDepth
  where
    maxDepth = maxExprDepth cfg
-- | Column references usable as 'Double' expressions.
--
-- Only unboxed columns are considered:
--
--   * 'Double' columns are referenced directly;
--   * integral columns are wrapped in 'F.toDouble';
--   * every other column (boxed, optional, other unboxed types) contributes
--     nothing.
numericCols :: DataFrame -> [Expr Double]
numericCols df = columnNames df >>= asDoubleExpr
  where
    asDoubleExpr :: T.Text -> [Expr Double]
    asDoubleExpr name = case unsafeGetColumn name df of
        UnboxedColumn (_ :: VU.Vector b)
            | Just Refl <- testEquality (typeRep @b) (typeRep @Double) ->
                [Col name]
            | STrue <- sIntegral @b ->
                [F.toDouble (Col @b name)]
        _ -> []
-- | Arithmetic expression search, level by level.
--
-- At @depth == 0@ the result starts with the base numeric columns
-- ('numericCols'). Each later level (while @depth < maxDepth@) combines every
-- expression of the previous level with every base column via:
--
--   * @+@, @-@ and @*@;
--   * a guarded division that yields @0@ where the divisor is @0@.
--
-- Identical pairs are skipped, as are pairs whose columns include both members
-- of a 'disallowedCombinations' entry. Levels are concatenated in order. When
-- 'enableArithOps' is off, no combined levels are produced.
numericExprs ::
    SynthConfig -> DataFrame -> [Expr Double] -> Int -> Int -> [Expr Double]
numericExprs cfg df prevExprs depth maxDepth
    | depth == 0 = baseExprs ++ nextLevels baseExprs
    | depth >= maxDepth = []
    | otherwise = combinedExprs ++ nextLevels combinedExprs
  where
    nextLevels seed = numericExprs cfg df seed (depth + 1) maxDepth

    baseExprs :: [Expr Double]
    baseExprs = numericCols df

    combinedExprs :: [Expr Double]
    combinedExprs
        | enableArithOps cfg =
            [ combined
            | lhs <- prevExprs
            , rhs <- baseExprs
            , lhs /= rhs
            , permitted (getColumns lhs <> getColumns rhs)
            , combined <-
                [ lhs + rhs
                , lhs - rhs
                , lhs * rhs
                , F.ifThenElse (rhs ./= 0) (lhs / rhs) 0
                ]
            ]
        | otherwise = []

    -- No disallowed (l, r) pair may have both columns present.
    permitted cols =
        all
            (\(l, r) -> l `notElem` cols || r `notElem` cols)
            (disallowedCombinations cfg)
-- | Boolean expression search, level by level.
--
-- At @depth == 0@ the result starts with @baseExprs@. Each later level (while
-- @depth < maxDepth@) pairs every expression of the previous level with every
-- base expression that differs from it, and emits their conjunction and then
-- their disjunction. Levels are concatenated in order. The 'DataFrame'
-- argument is not used.
boolExprs ::
    DataFrame -> [Expr Bool] -> [Expr Bool] -> Int -> Int -> [Expr Bool]
boolExprs frame baseExprs prevExprs depth maxDepth
    | depth == 0 = baseExprs ++ deeper prevExprs
    | depth >= maxDepth = []
    | otherwise = combined ++ deeper combined
  where
    deeper seed = boolExprs frame baseExprs seed (depth + 1) maxDepth

    combined :: [Expr Bool]
    combined =
        [ joined
        | lhs <- prevExprs
        , rhs <- baseExprs
        , lhs /= rhs
        , joined <- [F.and lhs rhs, F.or lhs rhs]
        ]
-- | Older condition generator (the name suggests it is superseded — confirm
-- before relying on it).
--
-- Produces single-column conditions for every column, followed by
-- column-versus-column conditions for every ordered column pair not listed in
-- 'disallowedCombinations'.
generateConditionsOld :: TreeConfig -> DataFrame -> [Expr Bool]
generateConditionsOld cfg df =
    concatMap perColumn (columnNames df) ++ concatMap pairConds allowedPairs
  where
    -- Conditions built from one column's own values.
    perColumn :: T.Text -> [Expr Bool]
    perColumn colName = case unsafeGetColumn colName df of
        BoxedColumn (col :: V.Vector a) ->
            [Col @a colName .== Lit (percentileOrd' p col) | p <- [1, 25, 75, 99]]
        OptionalColumn (col :: V.Vector (Maybe a)) -> case sFloating @a of
            STrue ->
                let doubleCol :: VU.Vector Double
                    doubleCol =
                        VU.convert (V.mapMaybe (fmap (realToFrac @a @Double)) col)
                    column = Col @(Maybe a) colName
                    thresholds :: [Expr (Maybe a)]
                    thresholds =
                        [ Lit (Just (realToFrac (percentile' p doubleCol)))
                        | p <- percentiles cfg
                        ]
                 in -- NOTE(review): zipWith stops at the 3 comparisons, so only
                    -- (== Nothing), (<= 1st percentile), (>= 2nd percentile)
                    -- are produced; later percentiles are ignored.
                    zipWith
                        ($)
                        [(column .==), (column .<=), (column .>=)]
                        (Lit Nothing : thresholds)
            SFalse -> case sIntegral @a of
                STrue ->
                    let doubleCol :: VU.Vector Double
                        doubleCol =
                            VU.convert (V.mapMaybe (fmap (fromIntegral @a @Double)) col)
                        column = Col @(Maybe a) colName
                        thresholds :: [Expr (Maybe a)]
                        thresholds =
                            [ Lit (Just (round (percentile' p doubleCol)))
                            | p <- percentiles cfg
                            ]
                     in -- Same 3-way truncation as the floating branch above.
                        zipWith
                            ($)
                            [(column .==), (column .<=), (column .>=)]
                            (Lit Nothing : thresholds)
                SFalse ->
                    [ Col @(Maybe a) colName .== Lit (percentileOrd' p col)
                    | p <- [1, 25, 75, 99]
                    ]
        UnboxedColumn (_ :: VU.Vector a) -> []

    -- Ordered pairs (a column paired with itself included) whose unordered
    -- form is not in disallowedCombinations.
    allowedPairs :: [(T.Text, T.Text)]
    allowedPairs =
        [ (l, r)
        | l <- columnNames df
        , r <- columnNames df
        , all
            (\(l', r') -> sort [l', r'] /= sort [l, r])
            (disallowedCombinations (synthConfig cfg))
        ]

    -- Conditions comparing two columns of the same element type.
    pairConds :: (T.Text, T.Text) -> [Expr Bool]
    pairConds (!l, !r) = case (unsafeGetColumn l df, unsafeGetColumn r df) of
        (BoxedColumn (_ :: V.Vector a), BoxedColumn (_ :: V.Vector b)) ->
            case testEquality (typeRep @a) (typeRep @b) of
                Just Refl -> [Col @a l .== Col @a r]
                Nothing -> []
        ( OptionalColumn (_ :: V.Vector (Maybe a))
            , OptionalColumn (_ :: V.Vector (Maybe b))
            ) ->
                case testEquality (typeRep @a) (typeRep @b) of
                    Nothing -> []
                    Just Refl -> case testEquality (typeRep @a) (typeRep @T.Text) of
                        Just Refl -> [Col @(Maybe a) l .== Col r]
                        Nothing ->
                            [Col @(Maybe a) l .<= Col r, Col @(Maybe a) l .== Col r]
        _ -> []
-- | Split a frame into the rows satisfying a condition and the rest.
--
-- Both halves come from 'filterWhere': the first uses @cond@, the second uses
-- @F.not cond@.
partitionDataFrame :: Expr Bool -> DataFrame -> (DataFrame, DataFrame)
partitionDataFrame cond df = (keep cond, keep (F.not cond))
  where
    keep predicate = filterWhere predicate df
-- | Smoothed Gini impurity of the target column.
--
-- Each observed class with count @c@ gets probability @(c + 1) / (n + k)@,
-- where @n@ is the row count and @k@ the number of distinct classes present.
-- This is add-one smoothing over the observed classes only. The impurity is
-- @1 - sum p^2@; an empty frame has impurity 0.
calculateGini :: forall a. (Columnable a) => T.Text -> DataFrame -> Double
calculateGini target df
    | rowCount == 0 = 0
    | otherwise = 1 - sum [p * p | p <- smoothed]
  where
    rowCount :: Double
    rowCount = fromIntegral (nRows df)

    counts = getCounts @a target df

    classCount :: Double
    classCount = fromIntegral (M.size counts)

    smoothed :: [Double]
    smoothed =
        [(fromIntegral c + 1) / (rowCount + classCount) | c <- M.elems counts]
-- | Most frequent value of the target column.
--
-- Ties on frequency are resolved in favour of the largest value. Calls
-- 'error' with "Empty DataFrame in leaf" when no counts are available.
majorityValue :: forall a. (Columnable a) => T.Text -> DataFrame -> a
majorityValue target df
    | M.null counts = error "Empty DataFrame in leaf"
    | otherwise = snd (maximum [(count, value) | (value, count) <- M.toList counts])
  where
    counts = getCounts @a target df
-- | Frequency table mapping each value of column @target@ to its number of
-- occurrences in @df@.
--
-- If 'interpret' or 'toVector' fails, the resulting exception is raised with
-- 'throw' when the map is demanded.
getCounts :: forall a. (Columnable a) => T.Text -> DataFrame -> M.Map a Int
getCounts target df = V.foldl' bump M.empty values
  where
    -- Increment the tally for one observed value.
    bump :: M.Map a Int -> a -> M.Map a Int
    bump tally x = M.insertWith (+) x 1 tally

    values :: V.Vector a
    values = case interpret @a df (Col target) of
      Left err -> throw err
      Right (TColumn col) -> either throw id (toVector @a col)
-- | Percentile @p@ of @expr@ evaluated over @df@.
--
-- The expression is interpreted as a 'Double' column and its values are
-- sorted ascending. With @n@ values, the element at index
-- @p * n@ integer-divided by 100 is returned, with the index clamped to
-- @[0, n - 1]@.
--
-- Returns @0@ in three cases: the expression cannot be interpreted as a
-- 'Double' column, the column cannot be read as a vector, or there are no
-- values.
percentile :: Int -> Expr Double -> DataFrame -> Double
percentile p expr df =
  case interpret @Double df expr of
    Right (TColumn col)
      | Right vals <- toVector @Double col ->
          pickRank (V.fromList (sort (V.toList vals)))
    -- Any interpretation or conversion failure falls through to 0.
    _ -> 0
  where
    -- Index into an already-sorted vector; an empty vector yields 0.
    pickRank :: V.Vector Double -> Double
    pickRank ordered
      | len == 0 = 0
      | otherwise = ordered V.! clampIndex ((p * len) `div` 100)
      where
        len = V.length ordered
        clampIndex = min (len - 1) . max 0
-- | Build a decision-tree expression that predicts column @target@.
--
-- The tree is built in three steps:
--
-- 1. Grow a tree with 'buildGreedyTree', starting at @depth@.
-- 2. Refine it with 'taoOptimize' over every row index @0 .. nRows df - 1@.
-- 3. Convert it with 'treeToExpr' and simplify the result with 'pruneExpr'.
buildTree ::
  forall a.
  (Columnable a) =>
  TreeConfig ->
  Int ->
  T.Text ->
  [Expr Bool] ->
  DataFrame ->
  Expr a
buildTree cfg depth target conds df =
  pruneExpr . treeToExpr $
    taoOptimize @a cfg target conds df allRows greedyTree
  where
    greedyTree = buildGreedyTree @a cfg depth target conds df
    -- Optimisation starts from the full set of rows.
    allRows = V.enumFromN 0 (nRows df)
-- | Best split condition for predicting @target@ under the given conditions.
--
-- This is a thin alias: it forwards all arguments unchanged to
-- 'findBestGreedySplit' at result type @a@.
findBestSplit ::
  forall a.
  (Columnable a) =>
  TreeConfig -> T.Text -> [Expr Bool] -> DataFrame -> Maybe (Expr Bool)
findBestSplit cfg target conds df =
  findBestGreedySplit @a cfg target conds df
-- | Simplify a tree expression.
--
-- This is a public alias that delegates directly to 'pruneExpr'.
pruneTree :: forall a. (Columnable a, Eq a) => Expr a -> Expr a
pruneTree expr = pruneExpr expr