module Algorithm.EqSat.SearchSRCache where
import Data.SRTree
import Data.SRTree.Datasets
import System.Random
import Control.Monad.State.Strict
import Algorithm.EqSat.Egraph
import Algorithm.SRTree.Likelihoods
import qualified Data.IntMap as IM
import qualified Data.IntSet as IntSet
import qualified Data.SRTree.Random as Random
import Data.Function ( on )
import Algorithm.SRTree.Likelihoods
import Algorithm.SRTree.NonlinearOpt
import Control.Monad ( when, replicateM, forM, forM_ )
import Algorithm.EqSat.Egraph
import Algorithm.SRTree.Opt
import Algorithm.EqSat.Info
import Algorithm.EqSat.Build
import Data.Maybe ( fromJust )
import Data.SRTree.Random
import Algorithm.EqSat.Queries
import Data.List ( maximumBy )
import qualified Data.Map.Strict as Map
import Control.Monad.Identity
import Debug.Trace
-- | The search monad: e-graph state ('EGraphST') on top of a random
-- generator state ('StdGen'), on top of a list of caches ('[ECache]'), in 'IO'.
type RndEGraph r = EGraphST (StateT StdGen (StateT [ECache] IO)) r
-- | Run an 'IO' action inside 'RndEGraph' by lifting it through all three
-- state layers (e-graph, random generator, cache list).
io :: IO a -> RndEGraph a
io action = lift (lift (lift action))
{-# INLINE io #-}
-- | Run a computation over the '[ECache]' state layer inside 'RndEGraph',
-- lifting it past the e-graph and random-generator layers.
getCache :: StateT [ECache] IO a -> RndEGraph a
getCache cacheAction = lift (lift cacheAction)
-- | Run a computation over the random-generator layer (which sits on top of
-- the cache layer) inside 'RndEGraph'.
rnd :: StateT StdGen (StateT [ECache] IO) a -> RndEGraph a
rnd rngAction = lift rngAction
{-# INLINE rnd #-}
-- | Cost of a single node: every leaf costs 1, a binary node costs 2 plus
-- its two child values, a unary node costs 3 plus its child value.
--
-- NOTE(review): the child 'Int's are added directly, so this assumes the
-- caller (e.g. 'fromTree') substitutes child costs for child ids — confirm.
myCost :: SRTree Int -> Int
myCost node = case node of
  Var _     -> 1
  Const _   -> 1
  Param _   -> 1
  Uni _ t   -> 3 + t
  Bin _ l r -> 2 + l + r
-- | Monadic while-loop: while @p@ holds for the current value, replace it
-- with the result of running @prog@ on it; return the first value for which
-- @p@ is 'False' (possibly the initial @arg@).
while :: Monad f => (t -> Bool) -> t -> (t -> f t) -> f t
while p arg prog
  | p arg     = prog arg >>= \next -> while p next prog
  | otherwise = pure arg
-- | Fit the parameters of e-class @root@ on the training split and score it.
--
-- Runs 'minimizeNLLEGraph' with local algorithm 'VAR1' on the training data
-- @(x, y, mYErr)@, passing @nIter@, the e-graph, @root@, @cache@ and the
-- starting parameters @thetaOrig@. Returns @(fitness, fitted parameters,
-- updated cache)@. The fitness is the second component returned by the
-- optimizer (presumably its objective value), replaced by @-Infinity@ when
-- it is NaN so that callers that maximise fitness never select it.
--
-- NOTE(review): the validation split (fifth argument) is not used, so the
-- fitness is measured on the training data only. If it is meant to be
-- measured on the validation split, as the parameter list suggests, compute
-- @negate (nll distribution mYErrVal xVal yVal tree theta)@ here instead
-- (with @tree@ the best expression of @root@). Also confirm the sign
-- convention of the optimizer's reported value.
fitnessFun :: Int -> Distribution -> DataSet -> DataSet -> EGraph -> EClassId -> ECache -> PVector -> (Double, PVector, ECache)
fitnessFun nIter distribution (x, y, mYErr) _dataVal egraph root cache thetaOrig
  | isNaN val = (-(1 / 0), theta, cache')
  | otherwise = (val, theta, cache')
  where
    (theta, val, _, cache') =
      minimizeNLLEGraph VAR1 distribution mYErr nIter x y egraph root cache thetaOrig
-- | Multi-start fitting of e-class @root@.
--
-- Draws @nRep@ random initial parameter vectors with 'randomVec'. Their
-- length is the number of unique parameters of @root@, plus 3 for 'ROXY' or
-- 1 for 'Gaussian'. Each vector is fitted with 'fitnessFun'. The result with
-- the highest fitness is returned.
--
-- Every restart starts from the same input @cache@. Only the cache of the
-- winning restart is returned. Values of @nRep@ below 1 are treated as 1, so
-- that 'maximumBy' never receives an empty list.
fitnessFunRep :: Int -> Int -> Distribution -> DataSet -> DataSet -> EClassId -> ECache -> RndEGraph (Double, PVector, ECache)
fitnessFunRep nRep nIter distribution dataTrain dataVal root cache = do
    egraph <- get
    let nParams = countParamsUniqEg egraph root + extraParams distribution
    -- At least one restart: 'maximumBy' fails on an empty list.
    thetaOrigs <- replicateM (max 1 nRep) (rnd (randomVec nParams))
    let fits = Prelude.map (fitnessFun nIter distribution dataTrain dataVal egraph root cache) thetaOrigs
    pure (maximumBy (compare `on` fst') fits)
  where
    fst' (a, _, _) = a
    -- Additional entries in the initial parameter vector per distribution.
    extraParams d
      | d == ROXY     = 3
      | d == Gaussian = 1
      | otherwise     = 0
-- | Evaluate e-class @root@ on every (train, validation) pair in
-- @dataTrainsVals@. Each pair is matched with the cache at the same position
-- of the '[ECache]' state and scored with 'fitnessFunRep' (@nRep@ restarts).
--
-- Returns the minimum fitness over the evaluated pairs, together with the
-- parameter vector chosen for each pair. The updated caches replace the ones
-- that were used. Caches beyond the number of evaluated pairs are kept, not
-- dropped.
--
-- The first argument (@shouldReparam@) is currently unused.
-- NOTE(review): pairs beyond the length of the cache list are skipped by
-- 'zip', and 'minimum' fails if no pair is evaluated. Callers presumably keep
-- one cache per pair — confirm.
fitnessMV :: Bool -> Int -> Int -> Distribution -> [(DataSet, DataSet)] -> EClassId -> RndEGraph (Double, [PVector])
fitnessMV _shouldReparam nRep nIter distribution dataTrainsVals root = do
    caches   <- getCache get
    response <- forM (Prelude.zip dataTrainsVals caches) $ \((dt, dv), cache) ->
      fitnessFunRep nRep nIter distribution dt dv root cache
    -- Write back the caches that were used; keep any surplus caches instead
    -- of truncating the state list to the number of evaluated pairs.
    let untouched = Prelude.drop (Prelude.length response) caches
    getCache $ put (Prelude.map trd response ++ untouched)
    pure (minimum (Prelude.map fst' response), Prelude.map snd' response)
  where
    fst' (a, _, _) = a
    snd' (_, b, _) = b
    trd  (_, _, c) = c
-- | Evaluate e-class @root@ on every (train, validation) pair in
-- @dataTrainsVals@. Each pair is matched with the cache at the same position
-- of the '[ECache]' state and scored with 'fitnessFunRep' (@nRep@ restarts).
-- The cache state is only read: updated caches are discarded.
--
-- Returns the minimum fitness over the evaluated pairs and the parameter
-- vector chosen for each pair. The first argument is currently unused.
fitnessMVNoCache :: Bool -> Int -> Int -> Distribution -> [(DataSet, DataSet)] -> EClassId -> RndEGraph (Double, [PVector])
fitnessMVNoCache _shouldReparam nRep nIter distribution dataTrainsVals root = do
    caches  <- getCache get
    results <- forM (Prelude.zip dataTrainsVals caches) scoreSplit
    let fits   = [fit | (fit, _, _) <- results]
        thetas = [theta | (_, theta, _) <- results]
    pure (minimum fits, thetas)
  where
    scoreSplit ((dTrain, dVal), splitCache) =
      fitnessFunRep nRep nIter distribution dTrain dVal root splitCache
-- | Insert expression @t@ into the e-graph (node costs from 'myCost'), score
-- @t@ with @fitFun@, and record the fitness and parameter vectors on its
-- e-class.
--
-- Returns the e-class id, canonicalised before @fitFun@ runs.
insertExpr :: Fix SRTree -> (Fix SRTree -> RndEGraph (Double, [PVector])) -> RndEGraph EClassId
insertExpr t fitFun = do
    ecId   <- fromTree myCost t >>= canonical
    (f, p) <- fitFun t
    insertFitness ecId f p
    pure ecId
-- | Compute and store the fitness of e-class @ec@ with @fitFun@, but only if
-- no fitness is recorded for it yet. Returns 'True' exactly when @fitFun@
-- was run.
updateIfNothing :: Monad m => (EClassId -> StateT EGraph m (Double, [PVector])) -> EClassId -> StateT EGraph m Bool
updateIfNothing fitFun ec = getFitness ec >>= maybe computeAndStore (const (pure False))
  where
    computeAndStore = do
      (fit, thetas) <- fitFun ec
      insertFitness ec fit thetas
      pure True
-- | Pick a random e-class from the set of unevaluated e-classes and
-- canonicalise it. Returns 'Nothing' when the set is empty or when the
-- chosen class's '_consts' info is anything other than 'NotConst'.
pickRndSubTree :: RndEGraph (Maybe EClassId)
pickRndSubTree = do
    pending <- gets (IntSet.toList . _unevaluated . _eDB)
    case pending of
      [] -> pure Nothing
      _  -> do
        ecId   <- rnd (randomFrom pending) >>= canonical
        consts <- gets (_consts . _info . (IM.! ecId) . _eClass)
        case consts of
          NotConst -> pure (Just ecId)
          _        -> pure Nothing
-- | Concatenate the results of 'getTopFitEClassWithSize' for every size
-- @1 .. maxSize@, passing @n@ as its second argument each time (presumably
-- the number of top e-classes to take per size).
getParetoEcsUpTo :: Monad m => Int -> Int -> StateT EGraph m [EClassId]
getParetoEcsUpTo n maxSize =
  fmap concat (forM [1 .. maxSize] (`getTopFitEClassWithSize` n))
-- | Concatenate the results of 'getTopDLEClassWithSize' for every size
-- @1 .. maxSize@, passing @n@ as its second argument each time (presumably
-- the number of top e-classes to take per size).
getParetoDLEcsUpTo :: Monad m => Int -> Int -> StateT EGraph m [EClassId]
getParetoDLEcsUpTo n maxSize =
  concat <$> traverse (\size -> getTopDLEClassWithSize size n) [1 .. maxSize]
-- | Return the first e-class reported by @'getTopFitEClassWithSize' n 1@,
-- canonicalised and paired with its recorded fitness.
--
-- The result is empty when no e-class is reported, and has exactly one
-- element otherwise.
getBestExprWithSize :: Monad m => Int -> StateT EGraph m [(EClassId, Maybe Double)]
getBestExprWithSize n = do
  ecs <- getTopFitEClassWithSize n 1 >>= traverse canonical
  case ecs of
    []       -> pure []
    (ec : _) -> do
      bestFit <- getFitness ec
      pure [(ec, bestFit)]
-- | Generate a random expression and insert it into the e-graph.
--
-- Steps:
--
-- * A coin toss chooses the @grow@ flag.
-- * A size is drawn with 'randomFrom' from @[4 .. maxSize]@ (or
--   @[1 .. maxSize]@ when @maxSize <= 4@).
-- * A tree is built with @'Random.randomTree' 3 8@ from @rndTerm@ and
--   @rndNonTerm@.
--
-- Returns the canonical e-class id of the inserted tree.
insertRndExpr :: Int -> Rng (StateT [ECache] IO) (Fix SRTree) -> Rng (StateT [ECache] IO) (SRTree ()) -> RndEGraph Int
insertRndExpr maxSize rndTerm rndNonTerm = do
    grow <- rnd toss
    let lowest = if maxSize > 4 then 4 else 1
    size <- rnd (randomFrom [lowest .. maxSize])
    tree <- rnd (Random.randomTree 3 8 size rndTerm rndNonTerm grow)
    ecId <- fromTree myCost tree
    canonical ecId
-- | Re-run @fitFun@ on e-class @ec@ and store the new fitness and parameter
-- vectors. The result is stored when no fitness was recorded before, or
-- when the new fitness is strictly greater than the recorded one.
refit :: Monad m => (EClassId -> StateT EGraph m (Double, [PVector])) -> EClassId -> StateT EGraph m ()
refit fitFun ec = do
  (newFit, thetas) <- fitFun ec
  oldFit           <- getFitness ec
  when (maybe True (newFit >) oldFit) $
    insertFitness ec newFit thetas
-- | Find the e-class with the greatest entry in the fitness range database,
-- canonicalise it, and pass it to @printExprFun@ with index @0@.
--
-- The first argument is currently unused.
printBest :: (Monad m, Num t) => p -> (t -> EClassId -> StateT EGraph m b) -> StateT EGraph m b
printBest _fitFun printExprFun = do
  bec <- gets (snd . getGreatest . _fitRangeDB . _eDB) >>= canonical
  printExprFun 0 bec
-- | Walk expression sizes @1 .. maxSize@ and take, for each size, the e-class
-- returned by 'getBestExprWithSize'. The class is accepted when its recorded
-- fitness is finite, not NaN, and at least the best fitness accepted so far
-- (initially @-Infinity@).
--
-- Each accepted class is refitted with @'refit' fitFun@ and printed with
-- @printExprFun ix@, where @ix@ counts accepted classes from 0. The printed
-- results are returned in size order.
--
-- The threshold is only raised by accepted classes. A class with no recorded
-- fitness is skipped, rather than crashing. A rejected @+Infinity@ fitness
-- does not block every later size.
paretoFront :: (EClassId -> RndEGraph (Double, [PVector])) -> Int -> (Int -> EClassId -> RndEGraph [String]) -> RndEGraph [[String]]
paretoFront fitFun maxSize printExprFun = go 1 0 (-(1.0 / 0.0))
  where
    go :: Int -> Int -> Double -> RndEGraph [[String]]
    go n ix f
      | n > maxSize = pure []
      | otherwise = do
          ecList <- getBestExprWithSize n
          case ecList of
            [] -> go (n + 1) ix f
            ((ec, mf) : _) -> do
              ec' <- canonical ec
              case mf of
                Just f'
                  | f' >= f && not (isNaN f') && not (isInfinite f') -> do
                      refit fitFun ec'
                      t  <- printExprFun ix ec'
                      ts <- go (n + 1) (ix + 1) f'
                      pure (t : ts)
                -- Rejected or missing fitness: keep the current threshold.
                _ -> go (n + 1) ix f
-- | Fit every e-class currently listed in the e-graph's @_unevaluated@ set.
--
-- The ids are read once, in 'IntSet.toList' order. For each id, @fitFun@
-- is run, and the resulting fitness and parameter vectors are stored with
-- 'insertFitness'.
--
-- Rendered type (from the source listing):
-- @(Int -> StateT EGraph m (Double, [PVector])) -> StateT EGraph m ()@
--
-- NOTE(review): whether 'insertFitness' also removes the id from
-- @_unevaluated@ is not visible here; confirm before relying on it.
evaluateUnevaluated fitFun =
  gets (IntSet.toList . _unevaluated . _eDB) >>= mapM_ fitAndStore
  where
    -- Fit a single e-class and record its result in the e-graph.
    fitAndStore cid = do
      (fitness, params) <- fitFun cid
      insertFitness cid fitness params
-- | Pick one e-class at random from the e-graph's @_unevaluated@ set,
-- fit it with @fitFun@, store the result with 'insertFitness', and
-- return the id of the chosen e-class.
--
-- The random choice is made by 'randomFrom', lifted with 'rnd'.
--
-- Rendered type (from the source listing):
-- @(Int -> RndEGraph (Double, [PVector])) -> RndEGraph Int@
--
-- NOTE(review): the behaviour when @_unevaluated@ is empty depends on
-- 'randomFrom', which is not shown here. Presumably callers must ensure
-- the set is non-empty; confirm.
evaluateRndUnevaluated fitFun = do
  pending <- gets (IntSet.toList . _unevaluated . _eDB)
  chosen  <- rnd (randomFrom pending)
  (fitness, params) <- fitFun chosen
  insertFitness chosen fitness params
  return chosen
-- | Membership tests against the e-graph's @_eNodeToEClass@ map.
--
-- 'doesExist' returns 'True' when the node is a key of that map.
-- 'doesNotExist' is its negation.
--
-- Both functions only read the state. The node is looked up exactly as
-- given and is not canonized here. NOTE(review): callers presumably pass
-- an already-canonized node; confirm.
doesExist, doesNotExist :: ENode -> RndEGraph Bool
doesExist node = Map.member node <$> gets _eNodeToEClass
doesNotExist node = not <$> doesExist node
-- | Walk a chain of "generation" constructors, starting from a node, and
-- report whether some node along the chain is absent from the e-graph.
--
-- Let @ec@ be the e-class the node maps to in @_eNodeToEClass@:
--
-- * If the node is absent, the result is 'True'.
-- * If the node is present and the generator list is empty, the result is
--   'False'.
-- * If the node is present and the next generator is 'Nothing', the
--   result is 'False'.
-- * If the node is present and the next generator is @Just gf@, the node
--   @gf ec@ is canonized and checked recursively against the remaining
--   generators.
--
-- A single 'Map.lookup' both tests membership and yields @ec@.
doesNotExistGens :: [Maybe (EClassId -> ENode)] -> ENode -> RndEGraph Bool
doesNotExistGens gens node = do
  found <- gets (Map.lookup node . _eNodeToEClass)
  case (found, gens) of
    (Nothing, _)                 -> pure True
    (Just _, [])                 -> pure False
    (Just _, Nothing : _)        -> pure False
    (Just ec, Just mkNext : rest) -> do
      next <- canonize (mkNext ec)
      doesNotExistGens rest next
-- | Decide whether wrapping a node with @parent@ would produce a new e-node.
--
-- The node is canonized first and then looked up in @_eNodeToEClass@:
--
-- * If the node is absent, the result is 'True'.
-- * Otherwise its e-class id is made canonical with 'canonical', @parent@
--   is applied to that id, and the resulting node is canonized. The
--   result is 'True' exactly when that parent node does not exist yet.
--
-- Rendered type (from the source listing):
-- @(Int -> ENode) -> ENode -> RndEGraph Bool@
checkToken parent en' = do
  node  <- canonize en'
  found <- gets (Map.lookup node . _eNodeToEClass)
  maybe (pure True) parentIsFresh found
  where
    -- True when the canonized parent of this e-class is not in the e-graph.
    parentIsFresh ec = do
      rootId    <- canonical ec
      candidate <- canonize (parent rootId)
      not <$> doesExist candidate