-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.EqSat.SearchSRCache
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :
--
-- Support functions for searching symbolic expressions with e-graphs.
--
-----------------------------------------------------------------------------

module Algorithm.EqSat.SearchSRCache where

import Data.SRTree
import Data.SRTree.Datasets
import System.Random
import Control.Monad.State.Strict
import Algorithm.EqSat.Egraph
import Algorithm.SRTree.Likelihoods
import qualified Data.IntMap as IM
import qualified Data.IntSet as IntSet
import qualified Data.SRTree.Random as Random
import Data.Function ( on )
import Algorithm.SRTree.Likelihoods
import Algorithm.SRTree.NonlinearOpt
import Control.Monad ( when, replicateM, forM, forM_ )
import Algorithm.EqSat.Egraph
import Algorithm.SRTree.Opt
import Algorithm.EqSat.Info
import Algorithm.EqSat.Build
import Data.Maybe ( fromJust )
import Data.SRTree.Random
import Algorithm.EqSat.Queries
import Data.List ( maximumBy )
import qualified Data.Map.Strict as Map
import Control.Monad.Identity

import Debug.Trace

-- | Environment of an e-graph computation with support for a random
-- generator ('StdGen'), a list of 'ECache's threaded as the innermost state,
-- and 'IO' at the bottom of the stack.
type RndEGraph a = EGraphST (StateT StdGen (StateT [ECache] IO)) a

-- | Run an 'IO' action inside 'RndEGraph' by lifting it through the three
-- state layers (e-graph, random generator, cache list).
io :: IO a -> RndEGraph a
io action = lift (lift (lift action))
{-# INLINE io #-}
-- | Run an action over the cache-list state (@StateT [ECache] IO@) inside
-- 'RndEGraph', lifting it past the e-graph and random-generator layers.
--
-- Despite the name, this does not read the caches by itself; callers pass
-- e.g. 'get' or @'put' xs@ to read or replace the cache list.
getCache :: StateT [ECache] IO a -> RndEGraph a
getCache = lift . lift
-- Inlined like the sibling lifters 'io' and 'rnd'.
{-# INLINE getCache #-}
-- | Run a random-generator action (state over 'StdGen' stacked on the cache
-- list and 'IO') inside 'RndEGraph' by lifting it past the e-graph layer.
rnd :: StateT StdGen (StateT [ECache] IO)  a -> RndEGraph a
rnd action = lift action
{-# INLINE rnd #-}

-- | Node cost used when inserting trees into the e-graph ('fromTree').
-- Leaves (variables, constants, parameters) cost 1; a binary node costs 2
-- plus the costs of its children; a unary node costs 3 plus its child's cost.
myCost :: SRTree Int -> Int
myCost node = case node of
  Var _       -> 1
  Const _     -> 1
  Param _     -> 1
  Bin _ lc rc -> 2 + lc + rc
  Uni _ child -> 3 + child

-- | Monadic while loop. Starting from @arg@, repeatedly runs @prog@ on the
-- current value while @p@ holds for it, and returns the first value for which
-- @p@ is 'False' (which is @arg@ itself if @p arg@ is already 'False').
while :: Monad f => (t -> Bool) -> t -> (t -> f t) -> f t
while p arg prog = loop arg
  where
    loop current
      | p current = prog current >>= loop
      | otherwise = pure current

-- | Optimise the parameters of the expression rooted at @root@ on the
-- training split with 'minimizeNLLEGraph' (using 'VAR1'), starting from
-- @thetaOrig@.
--
-- Returns the objective value reported by the optimiser, the optimised
-- parameters and the updated cache. A NaN objective is replaced by negative
-- infinity.
--
-- The validation split is currently unused: the bindings @tree@, @nParams@
-- and @evalF@ only feed the commented-out validation line at the bottom.
fitnessFun
  :: Int -> Distribution -> DataSet -> DataSet -> EGraph -> EClassId
  -> ECache -> PVector -> (Double, PVector, ECache)
fitnessFun nIter distribution (x, y, mYErr) (x_val, y_val, mYErr_val) egraph root cache thetaOrig
  | isNaN val = (negInf, theta, cache') -- || isNaN tr
  | otherwise = (val, theta, cache')
  where
    negInf :: Double
    negInf = -(1/0) -- infinity
    tree = runIdentity (evalStateT (getBestExpr root) egraph)
    nParams = countParamsUniqEg egraph root + extra
      where
        extra
          | distribution == ROXY     = 3
          | distribution == Gaussian = 1
          | otherwise                = 0
    (theta, val, _, cache') =
      minimizeNLLEGraph VAR1 distribution mYErr nIter x y egraph root cache thetaOrig
    evalF a b c =
      negate (nll distribution c a b tree (if nParams == 0 then thetaOrig else theta))
    -- val           = evalF x_val y_val mYErr_val

--{-# INLINE fitnessFun #-}

-- | Fit the expression rooted at @root@ on @dataTrain@ from several random
-- starting parameter vectors (one per restart, sized by the number of unique
-- parameters plus the distribution's extra parameters) and keep the run with
-- the highest fitness, as computed by 'fitnessFun'.
--
-- Returns that run's fitness, optimised parameters and cache. At least one
-- restart is always performed: with @nRep <= 0@ the list of starts used to be
-- empty and 'maximumBy' failed.
fitnessFunRep :: Int -> Int -> Distribution -> DataSet -> DataSet -> EClassId -> ECache -> RndEGraph (Double, PVector, ECache)
fitnessFunRep nRep nIter distribution dataTrain dataVal root cache = do
    egraph <- get
    let nParams = countParamsUniqEg egraph root
                + if distribution == ROXY then 3 else if distribution == Gaussian then 1 else 0
        fst' (a, _, _) = a
    -- 'maximumBy' is partial on empty lists, so never run fewer than one start.
    thetaOrigs <- replicateM (max 1 nRep) (rnd $ randomVec nParams)
    let fits = maximumBy (compare `on` fst')
             $ Prelude.map (fitnessFun nIter distribution dataTrain dataVal egraph root cache) thetaOrigs
    pure fits
--{-# INLINE fitnessFunRep #-}


-- | Evaluate the e-class @root@ on every (training, validation) split with
-- 'fitnessFunRep'. Split @i@ is paired with the @i@-th cache from the cache
-- state, and the updated caches are written back.
--
-- Returns the smallest fitness across the splits and the best parameter
-- vector found for each split. @shouldReparam@ is currently unused.
--
-- NOTE(review): 'zip' truncates to the shorter of the splits and the cache
-- list, and the caches written back have that truncated length. Confirm that
-- callers always keep one cache per split.
fitnessMV
  :: Bool -> Int -> Int -> Distribution -> [(DataSet, DataSet)] -> EClassId
  -> RndEGraph (Double, [PVector])
fitnessMV shouldReparam nRep nIter distribution dataTrainsVals root = do
  -- let tree = if shouldReparam then relabelParams _tree else relabelParamsOrder _tree
  -- WARNING: this should be done BEFORE inserting into egraph, so it's up to the algorithm'
  caches  <- getCache get
  results <- forM (Prelude.zip dataTrainsVals caches) $ \((dt, dv), cache) ->
    fitnessFunRep nRep nIter distribution dt dv root cache
  let (fits, thetas, newCaches) = Prelude.unzip3 results
  getCache (put newCaches)
  pure (minimum fits, thetas)

-- | Same evaluation as 'fitnessMV': each (training, validation) split is
-- scored with 'fitnessFunRep' using the matching cache from the cache state.
-- Unlike 'fitnessMV', the caches produced by the runs are discarded, so the
-- cache state is left unchanged.
--
-- Returns the smallest fitness across the splits and the best parameter
-- vector found for each split. @shouldReparam@ is currently unused.
fitnessMVNoCache
  :: Bool -> Int -> Int -> Distribution -> [(DataSet, DataSet)] -> EClassId
  -> RndEGraph (Double, [PVector])
fitnessMVNoCache shouldReparam nRep nIter distribution dataTrainsVals root = do
  -- let tree = if shouldReparam then relabelParams _tree else relabelParamsOrder _tree
  -- WARNING: this should be done BEFORE inserting into egraph, so it's up to the algorithm'
  caches  <- getCache get
  results <- forM (Prelude.zip dataTrainsVals caches) $ \((dt, dv), cache) ->
    fitnessFunRep nRep nIter distribution dt dv root cache
  let (fits, thetas, _) = Prelude.unzip3 results
  pure (minimum fits, thetas)



-- RndEGraph utils
-- fitFun fitnessFunRep rep iter distribution x y mYErr x_val y_val mYErr_val
-- | Insert the tree @t@ into the e-graph (using 'myCost'), evaluate @t@ with
-- @fitFun@ and record the resulting fitness and parameter vectors on the
-- canonical e-class obtained from the insertion.
--
-- Returns that canonical e-class id. The insertion happens before @fitFun@
-- runs, so @fitFun@ observes an e-graph that already contains @t@.
insertExpr :: Fix SRTree -> (Fix SRTree -> RndEGraph (Double, [PVector])) -> RndEGraph EClassId
insertExpr t fitFun = do
    ecId   <- fromTree myCost t >>= canonical
    (f, p) <- fitFun t
    insertFitness ecId f p
    pure ecId

-- | Compute and store the fitness of e-class @ec@ with @fitFun@ only if it
-- has no fitness yet. Returns 'True' when a fitness was computed and
-- inserted, and 'False' when one already existed.
updateIfNothing
  :: Monad m
  => (EClassId -> StateT EGraph m (Double, [PVector]))
  -> EClassId
  -> StateT EGraph m Bool
updateIfNothing fitFun ec = getFitness ec >>= maybe evaluate (const (pure False))
  where
    evaluate = do
      --t <- getBestExpr ec
      (f, p) <- fitFun ec
      insertFitness ec f p
      pure True

-- | Pick a random e-class among the unevaluated ones recorded in the e-graph
-- database and return its canonical id, but only if that class is
-- 'NotConst'. Returns 'Nothing' if there are no unevaluated classes or the
-- chosen class is constant.
pickRndSubTree :: RndEGraph (Maybe EClassId)
pickRndSubTree = do
  candidates <- gets (IntSet.toList . _unevaluated . _eDB)
  if null candidates
    then pure Nothing
    else do
      chosen    <- rnd (randomFrom candidates) >>= canonical
      constness <- gets (_consts . _info . (IM.! chosen) . _eClass)
      pure $ case constness of
        NotConst -> Just chosen
        _        -> Nothing

-- | Concatenate, for every size from 1 to @maxSize@, the e-classes returned
-- by @'getTopFitEClassWithSize' size n@ (the top-@n@ by fitness at that size).
getParetoEcsUpTo :: Monad m => Int -> Int -> StateT EGraph m [EClassId]
getParetoEcsUpTo n maxSize =
  concat <$> traverse (`getTopFitEClassWithSize` n) [1 .. maxSize]
-- | Concatenate, for every size from 1 to @maxSize@, the e-classes returned
-- by @'getTopDLEClassWithSize' size n@ (the top-@n@ by the DL criterion at
-- that size).
getParetoDLEcsUpTo :: Monad m => Int -> Int -> StateT EGraph m [EClassId]
getParetoDLEcsUpTo n maxSize =
  concat <$> traverse (`getTopDLEClassWithSize` n) [1 .. maxSize]

-- | Look up the single best e-class by fitness among those of size @n@ (via
-- @'getTopFitEClassWithSize' n 1@) and return it, canonicalised, together
-- with its stored fitness. Returns @[]@ when there is no e-class of that size.
--
-- The e-class is obtained by pattern matching instead of the partial 'head',
-- and the unused read of its parameter vectors has been removed.
getBestExprWithSize :: Monad m => Int -> StateT EGraph m [(EClassId, Maybe Double)]
getBestExprWithSize n = do
  ecs <- getTopFitEClassWithSize n 1 >>= traverse canonical
  case ecs of
    []       -> pure []
    (ec : _) -> do
      bestFit <- getFitness ec
      pure [(ec, bestFit)]

-- | Generate a random tree with 'Random.randomTree' and insert it into the
-- e-graph, returning its canonical e-class id.
--
-- The grow flag comes from a coin toss. The requested size is drawn with
-- 'randomFrom' from @[4 .. maxSize]@ when @maxSize > 4@, and from
-- @[1 .. maxSize]@ otherwise. The literals 3 and 8 are passed unchanged to
-- 'Random.randomTree' (presumably min/max depth; confirm against
-- "Data.SRTree.Random").
insertRndExpr
  :: Int
  -> Rng (StateT [ECache] IO) (Fix SRTree)
  -> Rng (StateT [ECache] IO) (SRTree ())
  -> RndEGraph EClassId
insertRndExpr maxSize rndTerm rndNonTerm = do
  grow <- rnd toss
  let minSize = if maxSize > 4 then 4 else 1
  size <- rnd (randomFrom [minSize .. maxSize])
  tree <- rnd (Random.randomTree 3 8 size rndTerm rndNonTerm grow)
  fromTree myCost tree >>= canonical

-- | Re-evaluate e-class @ec@ with @fitFun@. The new fitness and parameters
-- are stored when the class had no fitness yet, or when the new fitness is
-- strictly greater than the stored one. Otherwise the e-graph is left
-- unchanged.
refit
  :: Monad m
  => (EClassId -> StateT EGraph m (Double, [PVector]))
  -> EClassId
  -> StateT EGraph m ()
refit fitFun ec = do
  --t <- getBestExpr ec
  (newFit, params) <- fitFun ec
  stored           <- getFitness ec
  let improves = maybe True (newFit >) stored
  when improves $ insertFitness ec newFit params

--printBest :: (Int -> EClassId -> RndEGraph ()) -> RndEGraph ()
-- | Find the e-class with the greatest fitness in the fitness range database,
-- canonicalise its id and hand it to @printExprFun@ with index 0.
--
-- @fitFun@ is currently unused; it is only referenced by the commented-out
-- refit line below.
printBest
  :: (Monad m, Num t)
  => fit
  -> (t -> EClassId -> StateT EGraph m b)
  -> StateT EGraph m b
printBest fitFun printExprFun = do
      bec     <- canonical =<< gets (snd . getGreatest . _fitRangeDB . _eDB)
      bestFit <- gets (_fitness . _info . (IM.! bec) . _eClass)
      --refit fitFun bec
      --io.print $ "should be " <> show bestFit
      printExprFun 0 bec

-- | Walk expression sizes @1 .. maxSize@ and collect a Pareto front of the
-- best expression found for each size.
--
-- For each size @n@ the first entry returned by 'getBestExprWithSize' is the
-- candidate for that size. It joins the front only when its fitness is
-- present, finite (not NaN, not infinite) and at least as large as the best
-- fitness accepted so far (larger is treated as better). Each accepted
-- e-class is canonicalised, refitted with @fitFun@ and rendered with
-- @printExprFun@. @printExprFun@ receives the 0-based position of the entry
-- in the front and the canonical e-class id.
--
-- Candidates that are rejected, including those with a missing or
-- non-finite fitness, leave the acceptance threshold unchanged. A single
-- bogus @+Infinity@ therefore cannot block every larger size.
paretoFront :: (EClassId -> RndEGraph (Double, [PVector]))
            -> Int
            -> (Int -> EClassId -> RndEGraph [String])
            -> RndEGraph [[String]]
paretoFront fitFun maxSize printExprFun = go 1 0 (-(1.0/0.0))
  where
    -- go size frontIndex bestAcceptedFitness
    go :: Int -> Int -> Double -> RndEGraph [[String]]
    go n ix bestSoFar
      | n > maxSize = pure []
      | otherwise   = do
          ecList <- getBestExprWithSize n
          case ecList of
            (ec, Just f') : _
              | isFiniteFit f' && f' >= bestSoFar -> do
                  ec' <- canonical ec
                  refit fitFun ec'
                  t  <- printExprFun ix ec'
                  -- f' >= bestSoFar here, so f' is the new threshold
                  ts <- go (n + 1) (ix + 1) f'
                  pure (t : ts)
            -- no candidate, no fitness, invalid fitness, or not improving
            _ -> go (n + 1) ix bestSoFar

    isFiniteFit :: Double -> Bool
    isFiniteFit x = not (isNaN x || isInfinite x)

-- | Evaluate every e-class that the e-graph database lists as unevaluated.
--
-- The ids are read once, up front, from @_unevaluated@ and visited in
-- ascending order. For each id, @fitFun@ produces a fitness and parameter
-- vectors, which are stored with 'insertFitness'.
evaluateUnevaluated :: Monad m
                    => (EClassId -> StateT EGraph m (Double, [PVector]))
                    -> StateT EGraph m ()
evaluateUnevaluated fitFun =
    gets (IntSet.toList . _unevaluated . _eDB) >>= mapM_ evalOne
  where
    -- compute the fitness of one class and record it
    evalOne cid = fitFun cid >>= uncurry (insertFitness cid)

-- | Pick one unevaluated e-class at random, evaluate it with @fitFun@, store
-- the resulting fitness and parameters with 'insertFitness', and return the
-- chosen e-class id.
--
-- NOTE(review): the choice is made by 'randomFrom' over the current
-- @_unevaluated@ set. Presumably this fails when that set is empty, so
-- callers should check first. TODO confirm against 'randomFrom'.
evaluateRndUnevaluated :: (EClassId -> RndEGraph (Double, [PVector])) -> RndEGraph EClassId
evaluateRndUnevaluated fitFun =
    gets (IntSet.toList . _unevaluated . _eDB)
      >>= rnd . randomFrom
      >>= \chosen -> do
            (fit, params) <- fitFun chosen
            insertFitness chosen fit params
            pure chosen

-- | Check whether an e-node is ('doesExist') or is not ('doesNotExist') a key
-- of the e-graph's node-to-class map. The e-node is looked up exactly as
-- given; it is not canonized first.
doesExist, doesNotExist :: ENode -> RndEGraph Bool
doesExist en = gets (Map.member en . _eNodeToEClass)
doesNotExist = fmap not . doesExist

-- | Check whether the partial tree described by a list of ancestors, when
-- completed with the e-node @en@, would produce an expression that is not
-- yet in the e-graph.
--
-- The walk goes outward through the ancestors:
--
-- * if the current node is not in the node-to-class map, the result is 'True';
-- * if it is present and there are no further ancestors, or the next ancestor
--   is 'Nothing', the result is 'False';
-- * otherwise the next ancestor is applied to the current node's e-class id,
--   the resulting node is canonized, and the walk continues with the
--   remaining ancestors.
--
-- The initial @en@ is looked up as given; only the nodes built from the
-- ancestors are canonized.
doesNotExistGens :: [Maybe (EClassId -> ENode)] -> ENode -> RndEGraph Bool
doesNotExistGens ancestors en = do
    mEc <- gets (Map.lookup en . _eNodeToEClass)
    case (mEc, ancestors) of
      (Nothing, _)             -> pure True
      (Just _, [])             -> pure False
      (Just _, Nothing : _)    -> pure False
      (Just ec, Just gf : rest) -> canonize (gf ec) >>= doesNotExistGens rest

-- | Check whether combining the partial tree @parent@ with the e-node @en'@
-- would create an expression that is not yet in the e-graph.
--
-- @en'@ is canonized first. If the canonized node is not in the e-graph, the
-- result is 'True'. Otherwise @parent@ is applied to the canonical id of the
-- node's e-class. The result is canonized, and the function returns whether
-- that combined node is absent from the e-graph.
checkToken :: (EClassId -> ENode) -> ENode -> RndEGraph Bool
checkToken parent en' = do
    en <- canonize en'
    gets (Map.lookup en . _eNodeToEClass) >>= maybe (pure True) withParent
  where
    -- the token exists: test the node obtained by plugging its class into parent
    withParent ec = do
        root     <- canonical ec
        combined <- canonize (parent root)
        doesNotExist combined