-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.EqSat.SearchSR
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :
--
-- Support functions for searching symbolic expressions with e-graphs.
--
-----------------------------------------------------------------------------

module Algorithm.EqSat.SearchSR where

import Data.SRTree
import Data.SRTree.Datasets
import System.Random
import Control.Monad.State.Strict
import Algorithm.EqSat.Egraph
import Algorithm.SRTree.Likelihoods
import qualified Data.IntMap as IM
import qualified Data.IntSet as IntSet
import qualified Data.SRTree.Random as Random
import Data.Function ( on )
import Algorithm.SRTree.Likelihoods
import Algorithm.SRTree.NonlinearOpt
import Control.Monad ( when, replicateM, forM, forM_ )
import Algorithm.EqSat.Egraph
import Algorithm.SRTree.Opt
import Algorithm.EqSat.Info
import Algorithm.EqSat.Build
import Data.Maybe ( fromJust )
import Data.SRTree.Random
import Algorithm.EqSat.Queries
import Data.List ( maximumBy )
import qualified Data.Map.Strict as Map

-- | Monad used throughout the search: e-graph state ('EGraphST') stacked on a
-- random-generator state ('StateT' 'StdGen') over 'IO'.
type RndEGraph a = EGraphST (StateT StdGen IO) a

-- | Run an 'IO' action inside 'RndEGraph' by lifting it through both the
-- random-generator layer and the e-graph layer.
io :: IO a -> RndEGraph a
io action = lift (lift action)
{-# INLINE io #-}
-- | Run a random-generator action ('StateT' 'StdGen' over 'IO') inside
-- 'RndEGraph' by lifting it over the e-graph state layer.
rnd :: StateT StdGen IO a -> RndEGraph a
rnd gen = lift gen
{-# INLINE rnd #-}

-- | Node cost used when inserting trees with 'fromTree'.
--
-- The child slots of the node hold numbers that are added to the node's own
-- cost (presumably the children's already-computed costs — confirm against
-- 'fromTree'). Leaves cost 1, binary operators 2 plus both children, and
-- unary functions 3 plus their child.
myCost :: SRTree Int -> Int
myCost node = case node of
  Bin _ left right -> 2 + left + right
  Uni _ child      -> 3 + child
  Var _            -> 1
  Const _          -> 1
  Param _          -> 1

-- | Monadic while-loop: while @p@ holds for the current value, replace it by
-- the result of running @prog@ on it; return the first value for which @p@
-- is false. @prog@ is never run if @p arg@ is initially false.
while :: Monad f => (t -> Bool) -> t -> (t -> f t) -> f t
while p arg prog
  | p arg     = prog arg >>= \next -> while p next prog
  | otherwise = pure arg

-- | Fit the parameters of @tree@ on the training set and score it on the
-- validation set.
--
-- Parameters are optimised with 'minimizeNLL'' (algorithm 'VAR1', @nIter@
-- iterations) starting from @thetaOrig@. The score is the negated NLL on
-- the validation data (higher is better). A NaN score is reported as
-- negative infinity.
--
-- When the tree has no parameters (including the extra distribution
-- parameters: 3 for 'ROXY', 1 for 'Gaussian'), the score is computed with
-- @thetaOrig@ instead of the optimised vector.
-- NOTE(review): the returned vector is the optimiser output even in that
-- case — confirm this is intended.
fitnessFun :: Int -> Distribution -> DataSet -> DataSet -> Fix SRTree -> PVector -> (Double, PVector)
fitnessFun nIter distribution (x, y, mYErr) (x_val, y_val, mYErr_val) tree thetaOrig
  | isNaN score = (negInf, thetaFit)
  | otherwise   = (score, thetaFit)
  where
    negInf :: Double
    negInf = -(1/0)

    -- extra parameters introduced by the likelihood itself
    distParams :: Int
    distParams
      | distribution == ROXY     = 3
      | distribution == Gaussian = 1
      | otherwise                = 0

    nParams :: Int
    nParams = countParamsUniq tree + distParams

    (thetaFit, _, _) = minimizeNLL' VAR1 distribution mYErr nIter x y tree thetaOrig

    thetaEval :: PVector
    thetaEval = if nParams == 0 then thetaOrig else thetaFit

    score :: Double
    score = negate (nll distribution mYErr_val x_val y_val tree thetaEval)

--{-# INLINE fitnessFun #-}

-- | Run 'fitnessFun' from several random starting parameter vectors and keep
-- the result with the highest validation score.
--
-- @nRep@ is the number of random restarts. Values below 1 are treated as 1,
-- so the selection never runs on an empty list (which would make
-- 'maximumBy' throw).
fitnessFunRep :: Int -> Int -> Distribution -> DataSet -> DataSet -> Fix SRTree -> RndEGraph (Double, PVector)
fitnessFunRep nRep nIter distribution dataTrain dataVal tree = do
    let distParams :: Int
        distParams
          | distribution == ROXY     = 3
          | distribution == Gaussian = 1
          | otherwise                = 0
        nParams = countParamsUniq tree + distParams
        nStarts = max 1 nRep
    thetaOrigs <- replicateM nStarts (rnd (randomVec nParams))
    let candidates = Prelude.map (fitnessFun nIter distribution dataTrain dataVal tree) thetaOrigs
    pure (maximumBy (compare `on` fst) candidates)
--{-# INLINE fitnessFunRep #-}


-- | Evaluate a tree on several (training, validation) dataset pairs.
--
-- The tree's parameters are first relabelled ('relabelParams' when
-- @shouldReparam@, otherwise 'relabelParamsOrder'). Each pair is scored with
-- 'fitnessFunRep'. The overall fitness is the minimum (worst) score across
-- pairs, returned together with the best parameter vector for each pair, in
-- input order.
--
-- With no dataset pairs the fitness is negative infinity (the same sentinel
-- 'fitnessFun' uses for unusable scores) and the parameter list is empty,
-- rather than crashing in 'minimum'.
fitnessMV :: Bool -> Int -> Int -> Distribution -> [(DataSet, DataSet)] -> Fix SRTree -> RndEGraph (Double, [PVector])
fitnessMV shouldReparam nRep nIter distribution dataTrainsVals _tree = do
  let tree = if shouldReparam then relabelParams _tree else relabelParamsOrder _tree
  response <- forM dataTrainsVals $ \(dt, dv) -> fitnessFunRep nRep nIter distribution dt dv tree
  let (fits, thetas) = unzip response
      worst :: Double
      worst
        | null fits = -(1/0)
        | otherwise = minimum fits
  pure (worst, thetas)





-- RndEGraph utils
-- fitFun fitnessFunRep rep iter distribution x y mYErr x_val y_val mYErr_val
-- | Insert expression @t@ into the e-graph (node costs from 'myCost') and
-- evaluate it with @fitFun@. The fitness and parameter vectors are stored on
-- the expression's canonical e-class, whose id is returned.
insertExpr :: Fix SRTree -> (Fix SRTree -> RndEGraph (Double, [PVector])) -> RndEGraph EClassId
insertExpr t fitFun = do
    ecId <- fromTree myCost t >>= canonical
    (f, p) <- fitFun t
    insertFitness ecId f p
    pure ecId

-- | Evaluate e-class @ec@ with @fitFun@ only if it has no fitness yet.
--
-- When the e-class already has a fitness, nothing changes and the result is
-- 'False'. Otherwise its best expression is scored, the fitness and
-- parameters are stored, and the result is 'True'.
updateIfNothing :: Monad m => (Fix SRTree -> EGraphST m (Double, [PVector])) -> EClassId -> EGraphST m Bool
updateIfNothing fitFun ec = getFitness ec >>= maybe evaluate (const (pure False))
  where
    evaluate = do
      t <- getBestExpr ec
      (f, p) <- fitFun t
      insertFitness ec f p
      pure True

-- | Pick a random e-class from the set of unevaluated e-classes.
--
-- The chosen id is canonicalised. It is returned only when its constant
-- analysis is 'NotConst'; otherwise the result is 'Nothing'. The result is
-- also 'Nothing' when there are no unevaluated e-classes.
pickRndSubTree :: RndEGraph (Maybe EClassId)
pickRndSubTree = do
  candidates <- gets (IntSet.toList . _unevaluated . _eDB)
  if null candidates
    then pure Nothing
    else do
      chosen <- rnd (randomFrom candidates) >>= canonical
      constType <- gets (_consts . _info . (IM.! chosen) . _eClass)
      pure $ case constType of
        NotConst -> Just chosen
        _        -> Nothing

-- | Collect, for every expression size from 1 to @maxSize@, the top @n@
-- e-classes by fitness ('getTopFitEClassWithSize'), concatenated in order of
-- increasing size.
getParetoEcsUpTo :: Monad m => Int -> Int -> EGraphST m [EClassId]
getParetoEcsUpTo n maxSize = do
  perSize <- mapM (flip getTopFitEClassWithSize n) [1 .. maxSize]
  pure (concat perSize)
-- | Collect, for every expression size from 1 to @maxSize@, the top @n@
-- e-classes as ranked by 'getTopDLEClassWithSize', concatenated in order of
-- increasing size.
getParetoDLEcsUpTo :: Monad m => Int -> Int -> EGraphST m [EClassId]
getParetoDLEcsUpTo n maxSize = do
  perSize <- mapM (flip getTopDLEClassWithSize n) [1 .. maxSize]
  pure (concat perSize)

-- | Return the best-fitness e-class of size @n@, paired with its stored
-- fitness, as a singleton list; return an empty list when no e-class of that
-- size is available.
--
-- The id is canonicalised before its fitness is read.
getBestExprWithSize :: Monad m => Int -> EGraphST m [(EClassId, Maybe Double)]
getBestExprWithSize n = do
  ecs <- getTopFitEClassWithSize n 1 >>= traverse canonical
  case ecs of
    []       -> pure []
    (ec : _) -> do
      bestFit <- getFitness ec
      pure [(ec, bestFit)]

-- | Generate a random expression and insert it into the e-graph, returning
-- its canonical e-class id.
--
-- A coin toss selects the @grow@ flag for 'Random.randomTree' (called with
-- 3 and 8 as its first two arguments). The size argument is drawn with
-- 'randomFrom' from @[4 .. maxSize]@ when @maxSize > 4@, otherwise from
-- @[1 .. maxSize]@.
-- NOTE(review): with @maxSize < 1@ that list is empty; behaviour then
-- depends on 'randomFrom' — confirm callers never pass it.
insertRndExpr :: Int -> Rng IO (Fix SRTree) -> Rng IO (SRTree ()) -> RndEGraph EClassId
insertRndExpr maxSize rndTerm rndNonTerm = do
  grow <- rnd toss
  let minSize = if maxSize > 4 then 4 else 1
  size <- rnd (randomFrom [minSize .. maxSize])
  expr <- rnd (Random.randomTree 3 8 size rndTerm rndNonTerm grow)
  eid  <- fromTree myCost expr
  canonical eid

-- | Re-evaluate the best expression of e-class @ec@ with @fitFun@.
--
-- The new fitness and parameters are stored when the e-class had no fitness
-- yet, or when the new fitness is strictly greater than the stored one.
refit :: Monad m => (Fix SRTree -> EGraphST m (Double, [PVector])) -> EClassId -> EGraphST m ()
refit fitFun ec = do
  (f, p) <- getBestExpr ec >>= fitFun
  current <- getFitness ec
  when (maybe True (f >) current) $
    insertFitness ec f p

-- | Print the e-class with the greatest entry in the fitness range index
-- ('_fitRangeDB'). The e-class id is canonicalised and passed to
-- @printExprFun 0@.
--
-- The first argument is not used. It is kept only so existing callers
-- continue to compile.
-- NOTE(review): behaviour when no fitness has been recorded depends on
-- 'getGreatest' — confirm.
printBest :: (Monad m, Num t) => p -> (t -> EClassId -> EGraphST m b) -> EGraphST m b
printBest _fitFun printExprFun = do
      bec <- gets (snd . getGreatest . _fitRangeDB . _eDB) >>= canonical
      printExprFun 0 bec

-- | Build a Pareto front over expression size: for each size @1 .. maxSize@,
-- take the best e-class of that size and keep it only if its fitness is
-- finite and no worse than the last accepted fitness.
--
-- Each accepted e-class is refit with @fitFun@ and printed with
-- @printExprFun ix@, where @ix@ is its position in the front (0-based). The
-- printed results are returned in order of increasing size.
--
-- E-classes without a stored fitness are skipped. Rejected candidates
-- (NaN, infinite, or worse) never move the acceptance threshold.
paretoFront :: (Fix SRTree -> RndEGraph (Double, [PVector]))
            -> Int
            -> (Int -> EClassId -> RndEGraph [String])
            -> RndEGraph [[String]]
paretoFront fitFun maxSize printExprFun = go 1 0 (-(1.0/0.0))
  where
    go :: Int -> Int -> Double -> RndEGraph [[String]]
    go n ix f
      | n > maxSize = pure []
      | otherwise   = do
          ecList <- getBestExprWithSize n
          case ecList of
            [] -> go (n + 1) ix f
            ((ec, mf) : _) -> do
              ec' <- canonical ec
              case mf of
                Just f' | improves f f' -> do
                  refit fitFun ec'
                  t  <- printExprFun ix ec'
                  -- f' >= f here, so f' is the new threshold
                  ts <- go (n + 1) (ix + 1) f'
                  pure (t : ts)
                _ -> go (n + 1) ix f

    -- candidate joins the front only if finite and no worse than the best
    -- fitness accepted so far
    improves :: Double -> Double -> Bool
    improves best f' = f' >= best && not (isNaN f') && not (isInfinite f')

-- | Score every e-class currently in the unevaluated set with @fitFun@
-- (applied to its best expression) and store each resulting fitness and
-- parameter vectors.
--
-- The set is read once, before any evaluation takes place.
evaluateUnevaluated :: Monad m => (Fix SRTree -> EGraphST m (Double, [PVector])) -> EGraphST m ()
evaluateUnevaluated fitFun = do
  pending <- gets (IntSet.toList . _unevaluated . _eDB)
  mapM_ scoreClass pending
  where
    scoreClass c = do
      (f, p) <- getBestExpr c >>= fitFun
      insertFitness c f p

-- | Pick one e-class at random from the unevaluated set, score its best
-- expression with @fitFun@, store the fitness and parameters, and return
-- the chosen id.
-- NOTE(review): when the unevaluated set is empty the outcome depends on
-- 'randomFrom' with an empty list — confirm callers guard against it.
evaluateRndUnevaluated :: (Fix SRTree -> RndEGraph (Double, [PVector])) -> RndEGraph EClassId
evaluateRndUnevaluated fitFun = do
  pending <- gets (IntSet.toList . _unevaluated . _eDB)
  chosen <- rnd (randomFrom pending)
  (f, p) <- getBestExpr chosen >>= fitFun
  insertFitness chosen f p
  pure chosen

-- | Check whether an e-node is already present in the e-graph's
-- e-node → e-class map ('doesExist'), or absent from it ('doesNotExist').
doesExist, doesNotExist :: ENode -> RndEGraph Bool
doesExist en = gets (Map.member en . _eNodeToEClass)
doesNotExist en = not <$> doesExist en

-- | Check whether the partial tree described by a chain of ancestors, once
-- combined with the e-node @en@, yields an expression that is not yet in
-- the e-graph.
--
-- Each list element is an optional ancestor builder. Given the e-class id
-- of the current e-node, it produces the e-node one level up. The walk
-- proceeds upward, consuming one list element per step:
--
--   * if the current e-node is not in the e-graph, the result is 'True';
--   * if it is present and there is no ancestor left to try (empty list or
--     a 'Nothing' entry), the result is 'False';
--   * otherwise the ancestor is built over the current e-node's e-class id,
--     canonized, and the check continues with the remaining ancestors.
--
-- The initial @en@ is looked up as given (it is not canonized here).
doesNotExistGens :: [Maybe (EClassId -> ENode)] -> ENode -> RndEGraph Bool
doesNotExistGens grands en = do
  -- One total lookup answers the membership question and yields the
  -- e-class id, replacing the previous 'notMember' + partial 'Map.!' pair.
  mEc <- gets (Map.lookup en . _eNodeToEClass)
  case (mEc, grands) of
    (Nothing, _)                  -> pure True
    (Just _,  [])                 -> pure False
    (Just _,  Nothing : _)        -> pure False
    (Just ec, Just gf : grands')  -> canonize (gf ec) >>= doesNotExistGens grands'

-- | Decide whether wrapping the e-node @en'@ with the partial tree @parent@
-- produces an expression that is not yet in the e-graph.
--
-- @en'@ is canonized first.
--
-- If the canonized e-node is not in the e-graph, any expression built on
-- top of it is new, so the answer is 'True'.
--
-- Otherwise @parent@ is applied to the canonical id of its e-class and the
-- result is canonized. The answer is 'True' exactly when that e-node is not
-- already present.
checkToken :: (Int -> ENode) -> ENode -> RndEGraph Bool
checkToken parent en' = do
  canonEn <- canonize en'
  known   <- gets (Map.lookup canonEn . _eNodeToEClass)
  maybe (pure True) combineWith known
  where
    -- Build the parent e-node over the canonical e-class id and test it.
    combineWith eid = do
      rootId    <- canonical eid
      candidate <- canonize (parent rootId)
      not <$> doesExist candidate