{-# LANGUAGE TupleSections #-}
{-# LANGUAGE BangPatterns #-}
module Algorithm.EqSat.Build where
import System.Random (Random (randomR), StdGen)
import Control.Lens ( over )
import Control.Monad ( forM_, when, foldM, forM )
import Data.Maybe ( fromMaybe, catMaybes )
import Data.SRTree
import Algorithm.EqSat.Egraph
import Algorithm.EqSat.DB
import qualified Data.IntMap.Strict as IntMap
import Data.Map.Strict ( Map )
import qualified Data.Map.Strict as Map
import qualified Data.HashSet as Set
import Control.Monad.State.Strict
import Control.Monad.Identity
import Data.SRTree.Recursion (cataM)
import Algorithm.EqSat.Info
import qualified Data.IntSet as IntSet
import Data.Maybe
import Data.Sequence (Seq(..), (><))
import Data.List ( nub )
import Debug.Trace (trace, traceShow)
-- | Insert an e-node into the e-graph and return the id of its e-class.
--
-- Steps:
--
-- 1. The node is canonized.
-- 2. If 'calculateConsts' reports it as a constant value or a parameter
--    index, it is replaced by the corresponding 'Const' / 'Param' leaf.
-- 3. Otherwise, @Bin Sub a p@ and @Bin Div a p@ whose right child's class
--    is a parameter are stored as @Bin Add a p@ / @Bin Mul a p@. This is
--    presumably because a free parameter can absorb the sign or reciprocal
--    when it is refitted — confirm.
-- 4. If the resulting node is already hash-consed, its existing class id
--    is returned and the graph is untouched.
-- 5. Otherwise a fresh class is allocated; see 'insertFreshClass'.
add :: Monad m => CostFun -> ENode -> EGraphST m EClassId
add costFun enode =
  do enode''    <- canonize enode
     constEnode <- calculateConsts enode''
     enode'     <- case constEnode of
                     ConstVal x -> pure (Const x)
                     ParamIx  x -> pure (Param x)
                     _          -> absorbParamChild enode''
     maybeEid <- gets ((Map.!? enode') . _eNodeToEClass)
     case maybeEid of
       Just eid -> pure eid
       Nothing  -> insertFreshClass enode'
  where
    -- Rewrite a Sub/Div node whose right child is a parameter class into
    -- Add/Mul. Every other node is returned unchanged.
    absorbParamChild :: Monad m => ENode -> EGraphST m ENode
    absorbParamChild node =
      case node of
        Bin Sub c1 c2 -> swapOpIfParam Add c1 c2 node
        Bin Div c1 c2 -> swapOpIfParam Mul c1 c2 node
        _             -> pure node

    -- Return @Bin op c1 c2@ when class @c2@ is tagged as a parameter;
    -- otherwise return the original node. The lookup uses the partial
    -- 'IntMap.!', so @c2@ must be a key of the class map. It is expected
    -- to be, since the node was canonized first.
    swapOpIfParam :: Monad m => Op -> EClassId -> EClassId -> ENode -> EGraphST m ENode
    swapOpIfParam op c1 c2 node =
      do constType <- gets (_consts . _info . (IntMap.! c2) . _eClass)
         pure $ case constType of
                  ParamIx _ -> Bin op c1 c2
                  _         -> node

    -- Allocate a new e-class for a node not yet in the hash-cons map, and
    -- register it in:
    --   * the canonical map (as its own representative),
    --   * the hash-cons map and the worklist,
    --   * the parent sets of its children,
    --   * the class map, the DB ('addToDB'),
    --   * the size index and the unevaluated set.
    insertFreshClass :: Monad m => ENode -> EGraphST m EClassId
    insertFreshClass node =
      do curId <- gets (_nextId . _eDB)
         modify' $ over canonicalMap (IntMap.insert curId curId)
                 . over eNodeToEClass (Map.insert node curId)
                 . over (eDB . nextId) (+1)
                 . over (eDB . worklist) (Set.insert (curId, node))
         forM_ (childrenOf node) (addParents curId node)
         info <- makeAnalysis costFun node
         h    <- getChildrenMinHeight node
         let newClass = createEClass curId node info h
         modify' $ over eClass (IntMap.insert curId newClass)
         addToDB node curId
         -- Index the class under its expression size.
         modify' $ over (eDB . sizeDB)
                 $ IntMap.insertWith IntSet.union (_size info) (IntSet.singleton curId)
         modify' $ over (eDB . unevaluated) (IntSet.insert curId)
         pure curId

    -- Record @(cId, node)@ as a parent of child class @c@.
    addParents :: Monad m => EClassId -> ENode -> EClassId -> EGraphST m ()
    addParents cId node c =
      do ec <- getEClass c
         let ec' = ec { _parents = Set.insert (cId, node) (_parents ec) }
         modify' $ over eClass (IntMap.insert c ec')
-- | Repair the e-graph after a batch of merges.
--
-- Takes a snapshot of the congruence worklist and the analysis worklist,
-- then clears both. It runs 'repair' on every snapshotted worklist entry,
-- and then 'repairAnalysis' on every snapshotted analysis entry.
--
-- NOTE(review): this is a single pass, and new work it creates is not
-- processed here:
--
--   * 'repair' may call 'merge', which pushes parents back onto the
--     worklist and analysis sets.
--   * 'repairAnalysis' pushes parents onto the analysis set.
--
-- Those new entries remain pending for the next call. Callers needing a
-- fixed point presumably call this repeatedly — confirm.
rebuild :: Monad m => CostFun -> EGraphST m ()
rebuild costFun =
  do wl <- gets (_worklist . _eDB)
     al <- gets (_analysis . _eDB)
     modify' $ over (eDB . worklist) (const Set.empty)
             . over (eDB . analysis) (const Set.empty)
     forM_ wl (uncurry (repair costFun))
     forM_ al (uncurry (repairAnalysis costFun))
{-# INLINE rebuild #-}
-- | Restore hash-cons (congruence) consistency for one worklist entry.
--
-- Steps:
--
-- 1. Remove the stale hash-cons entry for @enode@.
-- 2. Canonize the node and its class id.
-- 3. If the canonical node is already mapped to some class, merge that
--    class with the node's class and map the node to the merged id.
-- 4. Otherwise, map the canonical node to the node's canonical class.
repair :: Monad m => CostFun -> EClassId -> ENode -> EGraphST m ()
repair costFun ecId enode =
  do modify' $ over eNodeToEClass (Map.delete enode)
     enode'  <- canonize enode
     ecId'   <- canonical ecId
     doExist <- gets ((Map.!? enode') . _eNodeToEClass)
     case doExist of
       Just ecIdCanon ->
         do mergedId <- merge costFun ecIdCanon ecId'
            modify' $ over eNodeToEClass (Map.insert enode' mergedId)
       Nothing ->
         modify' $ over eNodeToEClass (Map.insert enode' ecId')
{-# INLINE repair #-}
-- | Re-run the analysis for one analysis-worklist entry.
--
-- Computes 'makeAnalysis' for the canonized node and joins the result with
-- the existing data of the node's canonical class. If the joined data
-- differs from what is stored, it then:
--
--   * stores the new data,
--   * pushes the class's parents onto the analysis worklist,
--   * marks the class for refitting,
--   * calls 'modifyEClass' on it.
--
-- When the joined data is unchanged, nothing is modified.
repairAnalysis :: Monad m => CostFun -> EClassId -> ENode -> EGraphST m ()
repairAnalysis costFun ecId enode =
  do ecId'  <- canonical ecId
     enode' <- canonize enode
     eclass <- getEClass ecId'
     info   <- makeAnalysis costFun enode'
     let newData = joinData (_info eclass) info
         eclass' = eclass { _info = newData }
     when (_info eclass /= newData) $
       do modify' $ over (eDB . analysis) (_parents eclass <>)
                  . over eClass (IntMap.insert ecId' eclass')
                  . over (eDB . refits) (Set.insert ecId')
          _ <- modifyEClass costFun ecId'
          pure ()
{-# INLINE repairAnalysis #-}
merge :: Monad m => CostFun -> EClassId -> EClassId -> EGraphST m EClassId
merge :: forall (m :: * -> *).
Monad m =>
CostFun -> Int -> Int -> EGraphST m Int
merge CostFun
costFun Int
c1 Int
c2 =
do Int
c1' <- Int -> EGraphST m Int
forall (m :: * -> *). Monad m => Int -> EGraphST m Int
canonical Int
c1
Int
c2' <- Int -> EGraphST m Int
forall (m :: * -> *). Monad m => Int -> EGraphST m Int
canonical Int
c2
if Int
c1' Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
c2'
then Int -> EGraphST m Int
forall a. a -> StateT EGraph m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Int
c1'
else do (Int
led, EClass
ledC, Int
ledOrig, Int
sub, EClass
subC, Int
subOrig) <- Int
-> Int
-> Int
-> Int
-> StateT EGraph m (Int, EClass, Int, Int, EClass, Int)
forall {m :: * -> *} {c}.
Monad m =>
Int
-> c
-> Int
-> c
-> StateT EGraph m (Int, EClass, c, Int, EClass, c)
getLeaderSub Int
c1' Int
c1 Int
c2' Int
c2
Int -> EClass -> Int -> Int -> EClass -> Int -> EGraphST m Int
forall (m :: * -> *).
Monad m =>
Int -> EClass -> Int -> Int -> EClass -> Int -> EGraphST m Int
mergeClasses Int
led EClass
ledC Int
ledOrig Int
sub EClass
subC Int
subOrig
where
mergeClasses :: Monad m => EClassId -> EClass -> EClassId -> EClassId -> EClass -> EClassId -> EGraphST m EClassId
mergeClasses :: forall (m :: * -> *).
Monad m =>
Int -> EClass -> Int -> Int -> EClass -> Int -> EGraphST m Int
mergeClasses Int
led EClass
ledC Int
ledO Int
sub EClass
subC Int
subO =
do (EGraph -> EGraph) -> StateT EGraph m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' ((EGraph -> EGraph) -> StateT EGraph m ())
-> (EGraph -> EGraph) -> StateT EGraph m ()
forall a b. (a -> b) -> a -> b
$ ASetter EGraph EGraph (ClassIdMap Int) (ClassIdMap Int)
-> (ClassIdMap Int -> ClassIdMap Int) -> EGraph -> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ASetter EGraph EGraph (ClassIdMap Int) (ClassIdMap Int)
Lens' EGraph (ClassIdMap Int)
canonicalMap (Int -> Int -> ClassIdMap Int -> ClassIdMap Int
forall a. Int -> a -> IntMap a -> IntMap a
IntMap.insert Int
sub Int
led (ClassIdMap Int -> ClassIdMap Int)
-> (ClassIdMap Int -> ClassIdMap Int)
-> ClassIdMap Int
-> ClassIdMap Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> Int -> ClassIdMap Int -> ClassIdMap Int
forall a. Int -> a -> IntMap a -> IntMap a
IntMap.insert Int
subO Int
led)
let
newC :: EClass
newC = Int
-> HashSet ENodeEnc
-> HashSet (Int, ENode)
-> Int
-> EClassData
-> EClass
EClass Int
led
(EClass -> HashSet ENodeEnc
_eNodes EClass
ledC HashSet ENodeEnc -> HashSet ENodeEnc -> HashSet ENodeEnc
forall a. Eq a => HashSet a -> HashSet a -> HashSet a
`Set.union` EClass -> HashSet ENodeEnc
_eNodes EClass
subC)
(EClass -> HashSet (Int, ENode)
_parents EClass
ledC HashSet (Int, ENode)
-> HashSet (Int, ENode) -> HashSet (Int, ENode)
forall a. Semigroup a => a -> a -> a
<> EClass -> HashSet (Int, ENode)
_parents EClass
subC)
(Int -> Int -> Int
forall a. Ord a => a -> a -> a
min (EClass -> Int
_height EClass
ledC) (EClass -> Int
_height EClass
subC))
(EClassData -> EClassData -> EClassData
joinData (EClass -> EClassData
_info EClass
ledC) (EClass -> EClassData
_info EClass
subC))
(EGraph -> EGraph) -> StateT EGraph m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' ((EGraph -> EGraph) -> StateT EGraph m ())
-> (EGraph -> EGraph) -> StateT EGraph m ()
forall a b. (a -> b) -> a -> b
$ ASetter EGraph EGraph (IntMap EClass) (IntMap EClass)
-> (IntMap EClass -> IntMap EClass) -> EGraph -> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ASetter EGraph EGraph (IntMap EClass) (IntMap EClass)
Lens' EGraph (IntMap EClass)
eClass (Int -> EClass -> IntMap EClass -> IntMap EClass
forall a. Int -> a -> IntMap a -> IntMap a
IntMap.insert Int
led EClass
newC (IntMap EClass -> IntMap EClass)
-> (IntMap EClass -> IntMap EClass)
-> IntMap EClass
-> IntMap EClass
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> IntMap EClass -> IntMap EClass
forall a. Int -> IntMap a -> IntMap a
IntMap.delete Int
sub)
(EGraph -> EGraph) -> (EGraph -> EGraph) -> EGraph -> EGraph
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ASetter EGraph EGraph (HashSet (Int, ENode)) (HashSet (Int, ENode))
-> (HashSet (Int, ENode) -> HashSet (Int, ENode))
-> EGraph
-> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((HashSet (Int, ENode) -> Identity (HashSet (Int, ENode)))
-> EGraphDB -> Identity EGraphDB)
-> ASetter
EGraph EGraph (HashSet (Int, ENode)) (HashSet (Int, ENode))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (HashSet (Int, ENode) -> Identity (HashSet (Int, ENode)))
-> EGraphDB -> Identity EGraphDB
Lens' EGraphDB (HashSet (Int, ENode))
worklist) (EClass -> HashSet (Int, ENode)
_parents EClass
subC HashSet (Int, ENode)
-> HashSet (Int, ENode) -> HashSet (Int, ENode)
forall a. Semigroup a => a -> a -> a
<>)
Bool -> StateT EGraph m () -> StateT EGraph m ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (EClass -> EClassData
_info EClass
newC EClassData -> EClassData -> Bool
forall a. Eq a => a -> a -> Bool
/= EClass -> EClassData
_info EClass
ledC)
(StateT EGraph m () -> StateT EGraph m ())
-> StateT EGraph m () -> StateT EGraph m ()
forall a b. (a -> b) -> a -> b
$ (EGraph -> EGraph) -> StateT EGraph m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' ((EGraph -> EGraph) -> StateT EGraph m ())
-> (EGraph -> EGraph) -> StateT EGraph m ()
forall a b. (a -> b) -> a -> b
$ ASetter EGraph EGraph (HashSet (Int, ENode)) (HashSet (Int, ENode))
-> (HashSet (Int, ENode) -> HashSet (Int, ENode))
-> EGraph
-> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((HashSet (Int, ENode) -> Identity (HashSet (Int, ENode)))
-> EGraphDB -> Identity EGraphDB)
-> ASetter
EGraph EGraph (HashSet (Int, ENode)) (HashSet (Int, ENode))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (HashSet (Int, ENode) -> Identity (HashSet (Int, ENode)))
-> EGraphDB -> Identity EGraphDB
Lens' EGraphDB (HashSet (Int, ENode))
analysis) (EClass -> HashSet (Int, ENode)
_parents EClass
ledC HashSet (Int, ENode)
-> HashSet (Int, ENode) -> HashSet (Int, ENode)
forall a. Semigroup a => a -> a -> a
<>)
(EGraph -> EGraph) -> (EGraph -> EGraph) -> EGraph -> EGraph
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ASetter EGraph EGraph (HashSet Int) (HashSet Int)
-> (HashSet Int -> HashSet Int) -> EGraph -> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((HashSet Int -> Identity (HashSet Int))
-> EGraphDB -> Identity EGraphDB)
-> ASetter EGraph EGraph (HashSet Int) (HashSet Int)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (HashSet Int -> Identity (HashSet Int))
-> EGraphDB -> Identity EGraphDB
Lens' EGraphDB (HashSet Int)
refits) (Int -> HashSet Int -> HashSet Int
forall a. Hashable a => a -> HashSet a -> HashSet a
Set.insert Int
led)
Bool -> StateT EGraph m () -> StateT EGraph m ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (EClass -> EClassData
_info EClass
newC EClassData -> EClassData -> Bool
forall a. Eq a => a -> a -> Bool
/= EClass -> EClassData
_info EClass
subC)
(StateT EGraph m () -> StateT EGraph m ())
-> StateT EGraph m () -> StateT EGraph m ()
forall a b. (a -> b) -> a -> b
$ (EGraph -> EGraph) -> StateT EGraph m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' ((EGraph -> EGraph) -> StateT EGraph m ())
-> (EGraph -> EGraph) -> StateT EGraph m ()
forall a b. (a -> b) -> a -> b
$ ASetter EGraph EGraph (HashSet (Int, ENode)) (HashSet (Int, ENode))
-> (HashSet (Int, ENode) -> HashSet (Int, ENode))
-> EGraph
-> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((HashSet (Int, ENode) -> Identity (HashSet (Int, ENode)))
-> EGraphDB -> Identity EGraphDB)
-> ASetter
EGraph EGraph (HashSet (Int, ENode)) (HashSet (Int, ENode))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (HashSet (Int, ENode) -> Identity (HashSet (Int, ENode)))
-> EGraphDB -> Identity EGraphDB
Lens' EGraphDB (HashSet (Int, ENode))
analysis) (EClass -> HashSet (Int, ENode)
_parents EClass
subC HashSet (Int, ENode)
-> HashSet (Int, ENode) -> HashSet (Int, ENode)
forall a. Semigroup a => a -> a -> a
<>)
EClass
-> Int
-> EClass
-> Int
-> Int
-> EClass
-> Int
-> StateT EGraph m ()
forall (m :: * -> *).
Monad m =>
EClass
-> Int -> EClass -> Int -> Int -> EClass -> Int -> EGraphST m ()
updateDBs EClass
newC Int
led EClass
ledC Int
ledO Int
sub EClass
subC Int
subO
CostFun -> Int -> EGraphST m Int
forall (m :: * -> *). Monad m => CostFun -> Int -> EGraphST m Int
modifyEClass CostFun
costFun Int
led
Int -> EGraphST m Int
forall a. a -> StateT EGraph m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Int
led
getLeaderSub :: Int
-> c
-> Int
-> c
-> StateT EGraph m (Int, EClass, c, Int, EClass, c)
getLeaderSub Int
c1 c
c1O Int
c2 c
c2O =
do EClass
ec1 <- Int -> EGraphST m EClass
forall (m :: * -> *). Monad m => Int -> EGraphST m EClass
getEClass Int
c1
EClass
ec2 <- Int -> EGraphST m EClass
forall (m :: * -> *). Monad m => Int -> EGraphST m EClass
getEClass Int
c2
let n1 :: Int
n1 = HashSet (Int, ENode) -> Int
forall a. HashSet a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (EClass -> HashSet (Int, ENode)
_parents EClass
ec1)
n2 :: Int
n2 = HashSet (Int, ENode) -> Int
forall a. HashSet a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (EClass -> HashSet (Int, ENode)
_parents EClass
ec2)
(Int, EClass, c, Int, EClass, c)
-> StateT EGraph m (Int, EClass, c, Int, EClass, c)
forall a. a -> StateT EGraph m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((Int, EClass, c, Int, EClass, c)
-> StateT EGraph m (Int, EClass, c, Int, EClass, c))
-> (Int, EClass, c, Int, EClass, c)
-> StateT EGraph m (Int, EClass, c, Int, EClass, c)
forall a b. (a -> b) -> a -> b
$ if Int
n1 Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>= Int
n2
then (Int
c1, EClass
ec1, c
c1O, Int
c2, EClass
ec2, c
c2O)
else (Int
c2, EClass
ec2, c
c2O, Int
c1, EClass
ec1, c
c1O)
-- | Refresh the per-e-class databases for the classes involved in a union.
--
-- Runs 'updateFitnessDB' first and then 'updateSizeDB', passing both the same
-- seven arguments unchanged.  Presumably @newC@ is the merged class, @led@ /
-- @ledC@ / @ledO@ describe the leader and @sub@ / @subC@ / @subO@ the class
-- absorbed into it (confirm against the caller).
updateDBs :: Monad m => EClass -> EClassId -> EClass -> EClassId -> EClassId -> EClass -> EClassId -> EGraphST m ()
updateDBs newC led ledC ledO sub subC subO =
  updateFitnessDB newC led ledC ledO sub subC subO
    >> updateSizeDB newC led ledC ledO sub subC subO
-- | Update the size database ('sizeDB', a map from e-class size to the set of
-- e-class ids of that size) after a union.
--
-- Steps, applied right to left:
--
-- 1. @sub@ and @subO@ are removed from the bucket of @subC@'s size.
-- 2. @led@ and @ledO@ are removed from the bucket of @ledC@'s size.
-- 3. @led@ is added to the bucket of @newC@'s size.
--
-- Step 3 uses 'IntMap.insertWith' rather than 'IntMap.adjust'.  With 'adjust',
-- @led@ would be silently dropped whenever no bucket exists yet for the new
-- size.  'updateFitnessDB' guards against the same situation for 'sizeFitDB'
-- by creating the missing bucket first.  The two removals keep using
-- 'IntMap.adjust': deleting from an absent bucket is a no-op either way.
updateSizeDB :: Monad m => EClass -> EClassId -> EClass -> EClassId -> EClassId -> EClass -> EClassId -> EGraphST m ()
updateSizeDB newC led ledC ledO sub subC subO =
  modify' $ over (eDB . sizeDB) fun
  where
    sizeOf = _size . _info

    fun = IntMap.insertWith IntSet.union (sizeOf newC) (IntSet.singleton led)
        . IntMap.adjust (IntSet.delete led . IntSet.delete ledO) (sizeOf ledC)
        . IntMap.adjust (IntSet.delete sub . IntSet.delete subO) (sizeOf subC)
-- | Keep the fitness-related databases ('unevaluated', 'fitRangeDB' and the
-- per-size 'sizeFitDB') consistent after a union that produced @newC@ from a
-- leader (@led@ / @ledC@ / @ledO@) and an absorbed class (@sub@ / @subC@ /
-- @subO@).  NOTE(review): @ledO@ / @subO@ are presumably the pre-canonical
-- ids of the same classes; confirm against the caller.
--
-- If @newC@ has a fitness:
--
--   * only when it differs from the leader's old fitness: the leader's old
--     entries are dropped (from 'unevaluated' if it had no fitness, otherwise
--     from 'fitRangeDB' and the 'sizeFitDB' bucket of @ledC@'s size), then
--     @led@ is inserted with the new fitness into 'fitRangeDB' and into the
--     'sizeFitDB' bucket of @newC@'s size;
--   * the absorbed class's entries are dropped the same way (from
--     'unevaluated' if @subC@ has no fitness, otherwise from 'fitRangeDB' and
--     the 'sizeFitDB' bucket of @subC@'s size).
--
-- If @newC@ has no fitness, only 'unevaluated' is touched: @led@ is inserted
-- and @ledO@, @sub@ and @subO@ are deleted.
--
-- The 'fromJust' calls are only reached on branches where the corresponding
-- 'isJust' / 'isNothing' test has ruled out 'Nothing'.
updateFitnessDB :: Monad m => EClass -> EClassId -> EClass -> EClassId -> EClassId -> EClass -> EClassId -> EGraphST m ()
updateFitnessDB :: forall (m :: * -> *).
Monad m =>
EClass
-> Int -> EClass -> Int -> Int -> EClass -> Int -> EGraphST m ()
updateFitnessDB EClass
newC Int
led EClass
ledC Int
ledO Int
sub EClass
subC Int
subO =
if (Maybe Double -> Bool
forall a. Maybe a -> Bool
isJust Maybe Double
fitNew)
-- The merged class has a fitness value.
then do
Bool -> EGraphST m () -> EGraphST m ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Maybe Double
fitNew Maybe Double -> Maybe Double -> Bool
forall a. Eq a => a -> a -> Bool
/= Maybe Double
fitLed) (EGraphST m () -> EGraphST m ()) -> EGraphST m () -> EGraphST m ()
forall a b. (a -> b) -> a -> b
$ do
-- Drop the leader's stale entries: from 'unevaluated' when it had no
-- fitness, otherwise from 'fitRangeDB' and its 'sizeFitDB' bucket (szLed).
if Maybe Double -> Bool
forall a. Maybe a -> Bool
isNothing Maybe Double
fitLed
then (EGraph -> EGraph) -> EGraphST m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' ((EGraph -> EGraph) -> EGraphST m ())
-> (EGraph -> EGraph) -> EGraphST m ()
forall a b. (a -> b) -> a -> b
$ ASetter EGraph EGraph IntSet IntSet
-> (IntSet -> IntSet) -> EGraph -> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((IntSet -> Identity IntSet) -> EGraphDB -> Identity EGraphDB)
-> ASetter EGraph EGraph IntSet IntSet
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (IntSet -> Identity IntSet) -> EGraphDB -> Identity EGraphDB
Lens' EGraphDB IntSet
unevaluated) (Int -> IntSet -> IntSet
IntSet.delete Int
led (IntSet -> IntSet) -> (IntSet -> IntSet) -> IntSet -> IntSet
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> IntSet -> IntSet
IntSet.delete Int
ledO)
else (EGraph -> EGraph) -> EGraphST m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' ((EGraph -> EGraph) -> EGraphST m ())
-> (EGraph -> EGraph) -> EGraphST m ()
forall a b. (a -> b) -> a -> b
$ ASetter EGraph EGraph (RangeTree Double) (RangeTree Double)
-> (RangeTree Double -> RangeTree Double) -> EGraph -> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((RangeTree Double -> Identity (RangeTree Double))
-> EGraphDB -> Identity EGraphDB)
-> ASetter EGraph EGraph (RangeTree Double) (RangeTree Double)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (RangeTree Double -> Identity (RangeTree Double))
-> EGraphDB -> Identity EGraphDB
Lens' EGraphDB (RangeTree Double)
fitRangeDB) (Int -> Double -> RangeTree Double -> RangeTree Double
forall a. (Ord a, Show a) => Int -> a -> RangeTree a -> RangeTree a
removeRange Int
led (Maybe Double -> Double
forall a. HasCallStack => Maybe a -> a
fromJust Maybe Double
fitLed) (RangeTree Double -> RangeTree Double)
-> (RangeTree Double -> RangeTree Double)
-> RangeTree Double
-> RangeTree Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> Double -> RangeTree Double -> RangeTree Double
forall a. (Ord a, Show a) => Int -> a -> RangeTree a -> RangeTree a
removeRange Int
ledO (Maybe Double -> Double
forall a. HasCallStack => Maybe a -> a
fromJust Maybe Double
fitLed))
(EGraph -> EGraph) -> (EGraph -> EGraph) -> EGraph -> EGraph
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ASetter
EGraph
EGraph
(IntMap (RangeTree Double))
(IntMap (RangeTree Double))
-> (IntMap (RangeTree Double) -> IntMap (RangeTree Double))
-> EGraph
-> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((IntMap (RangeTree Double)
-> Identity (IntMap (RangeTree Double)))
-> EGraphDB -> Identity EGraphDB)
-> ASetter
EGraph
EGraph
(IntMap (RangeTree Double))
(IntMap (RangeTree Double))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (IntMap (RangeTree Double) -> Identity (IntMap (RangeTree Double)))
-> EGraphDB -> Identity EGraphDB
Lens' EGraphDB (IntMap (RangeTree Double))
sizeFitDB) ((RangeTree Double -> RangeTree Double)
-> Int -> IntMap (RangeTree Double) -> IntMap (RangeTree Double)
forall a. (a -> a) -> Int -> IntMap a -> IntMap a
IntMap.adjust (Int -> Double -> RangeTree Double -> RangeTree Double
forall a. (Ord a, Show a) => Int -> a -> RangeTree a -> RangeTree a
removeRange Int
ledO (Maybe Double -> Double
forall a. HasCallStack => Maybe a -> a
fromJust Maybe Double
fitLed) (RangeTree Double -> RangeTree Double)
-> (RangeTree Double -> RangeTree Double)
-> RangeTree Double
-> RangeTree Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> Double -> RangeTree Double -> RangeTree Double
forall a. (Ord a, Show a) => Int -> a -> RangeTree a -> RangeTree a
removeRange Int
led (Maybe Double -> Double
forall a. HasCallStack => Maybe a -> a
fromJust Maybe Double
fitLed)) Int
szLed)
-- Register the leader under its new fitness.  'insertWith (><) szNew Empty'
-- leaves an existing bucket unchanged and creates an empty one when missing,
-- so the following 'adjust' always has a bucket to insert into.
(EGraph -> EGraph) -> EGraphST m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' ((EGraph -> EGraph) -> EGraphST m ())
-> (EGraph -> EGraph) -> EGraphST m ()
forall a b. (a -> b) -> a -> b
$ ASetter EGraph EGraph (RangeTree Double) (RangeTree Double)
-> (RangeTree Double -> RangeTree Double) -> EGraph -> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((RangeTree Double -> Identity (RangeTree Double))
-> EGraphDB -> Identity EGraphDB)
-> ASetter EGraph EGraph (RangeTree Double) (RangeTree Double)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (RangeTree Double -> Identity (RangeTree Double))
-> EGraphDB -> Identity EGraphDB
Lens' EGraphDB (RangeTree Double)
fitRangeDB) (Int -> Double -> RangeTree Double -> RangeTree Double
forall a. (Ord a, Show a) => Int -> a -> RangeTree a -> RangeTree a
insertRange Int
led (Maybe Double -> Double
forall a. HasCallStack => Maybe a -> a
fromJust Maybe Double
fitNew))
(EGraph -> EGraph) -> (EGraph -> EGraph) -> EGraph -> EGraph
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ASetter
EGraph
EGraph
(IntMap (RangeTree Double))
(IntMap (RangeTree Double))
-> (IntMap (RangeTree Double) -> IntMap (RangeTree Double))
-> EGraph
-> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((IntMap (RangeTree Double)
-> Identity (IntMap (RangeTree Double)))
-> EGraphDB -> Identity EGraphDB)
-> ASetter
EGraph
EGraph
(IntMap (RangeTree Double))
(IntMap (RangeTree Double))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (IntMap (RangeTree Double) -> Identity (IntMap (RangeTree Double)))
-> EGraphDB -> Identity EGraphDB
Lens' EGraphDB (IntMap (RangeTree Double))
sizeFitDB) ((RangeTree Double -> RangeTree Double)
-> Int -> IntMap (RangeTree Double) -> IntMap (RangeTree Double)
forall a. (a -> a) -> Int -> IntMap a -> IntMap a
IntMap.adjust (Int -> Double -> RangeTree Double -> RangeTree Double
forall a. (Ord a, Show a) => Int -> a -> RangeTree a -> RangeTree a
insertRange Int
led (Maybe Double -> Double
forall a. HasCallStack => Maybe a -> a
fromJust Maybe Double
fitNew)) Int
szNew (IntMap (RangeTree Double) -> IntMap (RangeTree Double))
-> (IntMap (RangeTree Double) -> IntMap (RangeTree Double))
-> IntMap (RangeTree Double)
-> IntMap (RangeTree Double)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (RangeTree Double -> RangeTree Double -> RangeTree Double)
-> Int
-> RangeTree Double
-> IntMap (RangeTree Double)
-> IntMap (RangeTree Double)
forall a. (a -> a -> a) -> Int -> a -> IntMap a -> IntMap a
IntMap.insertWith RangeTree Double -> RangeTree Double -> RangeTree Double
forall a. Seq a -> Seq a -> Seq a
(><) Int
szNew RangeTree Double
forall a. Seq a
Empty)
-- Drop the absorbed class's entries: from 'unevaluated' when it had no
-- fitness, otherwise from 'fitRangeDB' and its 'sizeFitDB' bucket (szSub).
-- NOTE(review): indentation was lost in this listing, so it is not
-- recoverable here whether this step sits inside the 'when (fitNew /= fitLed)'
-- block or runs unconditionally in the enclosing 'then do'.  Check the
-- original source before relying on either reading.
if Maybe Double -> Bool
forall a. Maybe a -> Bool
isNothing Maybe Double
fitSub
then (EGraph -> EGraph) -> EGraphST m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' ((EGraph -> EGraph) -> EGraphST m ())
-> (EGraph -> EGraph) -> EGraphST m ()
forall a b. (a -> b) -> a -> b
$ ASetter EGraph EGraph IntSet IntSet
-> (IntSet -> IntSet) -> EGraph -> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((IntSet -> Identity IntSet) -> EGraphDB -> Identity EGraphDB)
-> ASetter EGraph EGraph IntSet IntSet
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (IntSet -> Identity IntSet) -> EGraphDB -> Identity EGraphDB
Lens' EGraphDB IntSet
unevaluated) (Int -> IntSet -> IntSet
IntSet.delete Int
sub (IntSet -> IntSet) -> (IntSet -> IntSet) -> IntSet -> IntSet
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> IntSet -> IntSet
IntSet.delete Int
subO)
else (EGraph -> EGraph) -> EGraphST m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' ((EGraph -> EGraph) -> EGraphST m ())
-> (EGraph -> EGraph) -> EGraphST m ()
forall a b. (a -> b) -> a -> b
$ ASetter EGraph EGraph (RangeTree Double) (RangeTree Double)
-> (RangeTree Double -> RangeTree Double) -> EGraph -> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((RangeTree Double -> Identity (RangeTree Double))
-> EGraphDB -> Identity EGraphDB)
-> ASetter EGraph EGraph (RangeTree Double) (RangeTree Double)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (RangeTree Double -> Identity (RangeTree Double))
-> EGraphDB -> Identity EGraphDB
Lens' EGraphDB (RangeTree Double)
fitRangeDB) (Int -> Double -> RangeTree Double -> RangeTree Double
forall a. (Ord a, Show a) => Int -> a -> RangeTree a -> RangeTree a
removeRange Int
sub (Maybe Double -> Double
forall a. HasCallStack => Maybe a -> a
fromJust Maybe Double
fitSub) (RangeTree Double -> RangeTree Double)
-> (RangeTree Double -> RangeTree Double)
-> RangeTree Double
-> RangeTree Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> Double -> RangeTree Double -> RangeTree Double
forall a. (Ord a, Show a) => Int -> a -> RangeTree a -> RangeTree a
removeRange Int
subO (Maybe Double -> Double
forall a. HasCallStack => Maybe a -> a
fromJust Maybe Double
fitSub))
(EGraph -> EGraph) -> (EGraph -> EGraph) -> EGraph -> EGraph
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ASetter
EGraph
EGraph
(IntMap (RangeTree Double))
(IntMap (RangeTree Double))
-> (IntMap (RangeTree Double) -> IntMap (RangeTree Double))
-> EGraph
-> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((IntMap (RangeTree Double)
-> Identity (IntMap (RangeTree Double)))
-> EGraphDB -> Identity EGraphDB)
-> ASetter
EGraph
EGraph
(IntMap (RangeTree Double))
(IntMap (RangeTree Double))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (IntMap (RangeTree Double) -> Identity (IntMap (RangeTree Double)))
-> EGraphDB -> Identity EGraphDB
Lens' EGraphDB (IntMap (RangeTree Double))
sizeFitDB) ((RangeTree Double -> RangeTree Double)
-> Int -> IntMap (RangeTree Double) -> IntMap (RangeTree Double)
forall a. (a -> a) -> Int -> IntMap a -> IntMap a
IntMap.adjust (Int -> Double -> RangeTree Double -> RangeTree Double
forall a. (Ord a, Show a) => Int -> a -> RangeTree a -> RangeTree a
removeRange Int
subO (Maybe Double -> Double
forall a. HasCallStack => Maybe a -> a
fromJust Maybe Double
fitSub) (RangeTree Double -> RangeTree Double)
-> (RangeTree Double -> RangeTree Double)
-> RangeTree Double
-> RangeTree Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> Double -> RangeTree Double -> RangeTree Double
forall a. (Ord a, Show a) => Int -> a -> RangeTree a -> RangeTree a
removeRange Int
sub (Maybe Double -> Double
forall a. HasCallStack => Maybe a -> a
fromJust Maybe Double
fitSub)) Int
szSub)
-- The merged class has no fitness: only 'unevaluated' is updated (led is
-- added; ledO, sub and subO are removed).
else (EGraph -> EGraph) -> EGraphST m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' ((EGraph -> EGraph) -> EGraphST m ())
-> (EGraph -> EGraph) -> EGraphST m ()
forall a b. (a -> b) -> a -> b
$ ASetter EGraph EGraph IntSet IntSet
-> (IntSet -> IntSet) -> EGraph -> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((IntSet -> Identity IntSet) -> EGraphDB -> Identity EGraphDB)
-> ASetter EGraph EGraph IntSet IntSet
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (IntSet -> Identity IntSet) -> EGraphDB -> Identity EGraphDB
Lens' EGraphDB IntSet
unevaluated) (Int -> IntSet -> IntSet
IntSet.insert Int
led (IntSet -> IntSet) -> (IntSet -> IntSet) -> IntSet -> IntSet
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> IntSet -> IntSet
IntSet.delete Int
ledO (IntSet -> IntSet) -> (IntSet -> IntSet) -> IntSet -> IntSet
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> IntSet -> IntSet
IntSet.delete Int
sub (IntSet -> IntSet) -> (IntSet -> IntSet) -> IntSet -> IntSet
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> IntSet -> IntSet
IntSet.delete Int
subO)
-- Fitness (Maybe Double) and size of the merged, leader and absorbed classes,
-- read from each class's '_info'.
where
fitNew :: Maybe Double
fitNew = (EClassData -> Maybe Double
_fitness (EClassData -> Maybe Double)
-> (EClass -> EClassData) -> EClass -> Maybe Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. EClass -> EClassData
_info) EClass
newC
fitLed :: Maybe Double
fitLed = (EClassData -> Maybe Double
_fitness (EClassData -> Maybe Double)
-> (EClass -> EClassData) -> EClass -> Maybe Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. EClass -> EClassData
_info) EClass
ledC
fitSub :: Maybe Double
fitSub = (EClassData -> Maybe Double
_fitness (EClassData -> Maybe Double)
-> (EClass -> EClassData) -> EClass -> Maybe Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. EClass -> EClassData
_info) EClass
subC
szNew :: Int
szNew = (EClassData -> Int
_size (EClassData -> Int) -> (EClass -> EClassData) -> EClass -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. EClass -> EClassData
_info) EClass
newC
szLed :: Int
szLed = (EClassData -> Int
_size (EClassData -> Int) -> (EClass -> EClassData) -> EClass -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. EClass -> EClassData
_info) EClass
ledC
szSub :: Int
szSub = (EClassData -> Int
_size (EClassData -> Int) -> (EClass -> EClassData) -> EClass -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. EClass -> EClassData
_info) EClass
subC
-- | Collapse an e-class whose constant analysis identified it as a single
-- term node.
--
-- The class @ecId@ is looked up and its '_consts' analysis inspected:
--
-- * @'ConstVal' x@: the class's e-node set is replaced by the single node
--   @Const x@.
-- * @'ParamIx' ix@: the node @Param ix@ is added to the existing e-node set.
-- * anything else: nothing changes and @ecId@ is returned.
--
-- In the first two cases:
--
-- * the class's cost, best node and '_consts' are recomputed for the term node;
-- * the class is queued in 'refits' when it already had a fitness value;
-- * if the term node is already registered to some e-class in
--   '_eNodeToEClass', that class and @ecId@ are passed to 'merge' and its
--   result is returned.
--
-- When no such class is registered, @ecId@ is returned.
modifyEClass :: Monad m => CostFun -> EClassId -> EGraphST m EClassId
modifyEClass costFun ecId =
  do ec <- getEClass ecId
     case (_consts . _info) ec of
       ConstVal x -> replaceBy ec (Const x) Set.singleton
       ParamIx x  -> replaceBy ec (Param x) (\e -> Set.insert e (_eNodes ec))
       _          -> pure ecId
  where
    -- Shared logic of the 'ConstVal' and 'ParamIx' cases.  Installs @en@ as
    -- the best node of @ec@ and sets the class's e-nodes to
    -- @nodesWith (encodeEnode en)@.  It then merges with any class that
    -- already owns @en@.
    replaceBy ec en nodesWith = do
      c <- calculateCost costFun en
      let infoEc = (_info ec){ _cost = c, _best = en, _consts = toConst en }
      maybeEid <- gets ((Map.!? en) . _eNodeToEClass)
      modify' $ over eClass (IntMap.insert ecId ec{ _eNodes = nodesWith (encodeEnode en), _info = infoEc })
      -- a class that was already evaluated must be refitted with its new form
      when (isJust . _fitness . _info $ ec) $
        modify' $ over (eDB . refits) (Set.insert ecId)
      maybe (pure ecId) (\eid' -> merge costFun eid' ecId) maybeEid

    -- Constant-analysis value carried by a term node.
    toConst (Param ix) = ParamIx ix
    toConst (Const x)  = ConstVal x
    toConst _          = NotConst
-- | Rebuild the pattern database from every e-node registered in
-- '_eNodeToEClass'.
--
-- The current 'patDB' is cleared first.  Each @(enode, eclassId)@ entry is
-- then inserted through 'addToDB', and the resulting database is returned.
createDB :: Monad m => EGraphST m DB
createDB = do
  modify' (over (eDB . patDB) (const Map.empty))
  nodeIndex <- gets (Map.toList . _eNodeToEClass)
  forM_ nodeIndex $ \(enode, eid) -> addToDB enode eid
  gets (_patDB . _eDB)
{-# INLINE createDB #-}
-- | Rebuild the pattern database from the best e-node of each e-class.
--
-- The current 'patDB' is cleared first.  Then, for every @(eclassId, eclass)@
-- in '_eClass', the class's '_best' node is inserted through 'addToDB'.  The
-- resulting database is returned.
createDBBest :: Monad m => EGraphST m DB
createDBBest = do
  modify' (over (eDB . patDB) (const Map.empty))
  classes <- gets (IntMap.toList . _eClass)
  forM_ classes $ \(eId, ec) -> addToDB (_best (_info ec)) eId
  gets (_patDB . _eDB)
-- | Records an e-node belonging to e-class @eid@ in the pattern database.
--
-- The database holds one 'IntTrie' per operator shape ('getOperator'); the
-- path inserted into that trie is @eid@ followed by the e-node's children.
-- When the canonical e-class is known to be a constant value or a
-- parameter, the e-node is replaced by that leaf ('Const' / 'Param') before
-- being indexed, so it contributes a childless path.
--
-- NOTE(review): the trie path starts with @eid@ as given (only the constant
-- lookup uses the canonical id); callers presumably pass canonical ids —
-- confirm.
addToDB :: Monad m => ENode -> EClassId -> EGraphST m ()
addToDB enode' eid = do
  canonId     <- canonical eid
  classConsts <- gets (_consts . _info . (IntMap.! canonId) . _eClass)
  -- constant / parameter classes are indexed by their leaf value
  let node = case classConsts of
               ConstVal x -> Const x
               ParamIx ix -> Param ix
               _          -> enode'
      opKey  = getOperator node
      idPath = eid : childrenOf node
  existing <- gets (Map.lookup opKey . _patDB . _eDB)
  -- only store a trie when 'populate' produced one
  maybe (pure ())
        (\updated -> modify' (over (eDB . patDB) (Map.insert opKey updated)))
        (populate existing idPath)
{-# INLINE addToDB #-}
-- | Inserts a path of e-class ids into an optional 'IntTrie'.
--
-- * An empty path yields 'Nothing'.
-- * Without an existing trie, a fresh chain is built with one node per id.
-- * With an existing trie, the head id is added to the node's key set, and
--   the rest of the path is inserted recursively into the child stored
--   under that id. A fresh leaf is used when the recursion returns
--   'Nothing', which happens only at the end of the path.
populate :: Maybe IntTrie -> [EClassId] -> Maybe IntTrie
populate _ [] = Nothing
populate Nothing eids = foldr link Nothing eids
  where
    -- wrap the already-built suffix (if any) under the current id
    link :: EClassId -> Maybe IntTrie -> Maybe IntTrie
    link eid suffix = Just . trie eid $ maybe IntMap.empty (IntMap.singleton eid) suffix
populate (Just node) (eid:rest) =
    Just $ IntTrie (Set.insert eid (_keys node)) (IntMap.insert eid child children)
  where
    children = _trie node
    child    = fromMaybe (trie eid IntMap.empty) (populate (IntMap.lookup eid children) rest)
{-# INLINE populate #-}
-- | Canonicalises every e-class id ('Left') stored as a value of a match
-- substitution. Pattern-variable values ('Right') are kept unchanged.
-- The match root (second component) is returned untouched.
canonizeMap :: Monad m => (Map ClassOrVar ClassOrVar, ClassOrVar) -> EGraphST m (Map ClassOrVar ClassOrVar, ClassOrVar)
canonizeMap (subst, cv) = (, cv) <$> traverse canonicalClass subst
  where
    canonicalClass :: Monad m => ClassOrVar -> EGraphST m ClassOrVar
    canonicalClass (Left eid) = Left <$> canonical eid
    canonicalClass other      = pure other
{-# INLINE canonizeMap #-}
-- | Applies a rewrite rule to a single match.
--
-- The match substitution is canonicalised first ('canonizeMap'). If the
-- matched root passes 'isValidHeight' and every rule condition holds, the
-- rule's target pattern is instantiated with 'reprPrat', which adds any
-- missing e-nodes. The resulting e-class is then merged with the matched
-- root.
applyMatch :: Monad m => CostFun -> Rule -> (Map ClassOrVar ClassOrVar, ClassOrVar) -> EGraphST m ()
applyMatch costFun rule rawMatch = do
  canonMatch@(subst, root) <- canonizeMap rawMatch
  heightOk <- isValidHeight canonMatch
  condsOk  <- traverse (`isValidConditions` canonMatch) (getConditions rule)
  when (heightOk && and condsOk) $ do
    rhsClass <- reprPrat costFun subst (target rule)
    _ <- merge costFun (getInt root) rhsClass
    pure ()
{-# INLINE applyMatch #-}
-- | Variant of 'applyMatch' that resolves the rule's target with
-- 'classOfENode' instead of 'reprPrat'.
--
-- The merge with the matched root happens only when 'classOfENode' returns
-- an e-class. The same height and condition checks as 'applyMatch' apply.
applyMergeOnlyMatch :: Monad m => CostFun -> Rule -> (Map ClassOrVar ClassOrVar, ClassOrVar) -> EGraphST m ()
applyMergeOnlyMatch costFun rule rawMatch = do
  canonMatch@(subst, root) <- canonizeMap rawMatch
  heightOk <- isValidHeight canonMatch
  condsOk  <- traverse (`isValidConditions` canonMatch) (getConditions rule)
  when (heightOk && and condsOk) $ do
    found <- classOfENode costFun subst (target rule)
    -- Nothing: the target is absent, nothing to merge
    mapM_ (merge costFun (getInt root)) found
{-# INLINE applyMergeOnlyMatch #-}
-- | Finds the e-class representing a pattern under a substitution.
-- Returns 'Nothing' when the e-class cannot be found.
--
-- * @VarPat c@: the canonical class bound to @c@, if any.
-- * @Fixed (Const x)@: the constant is added via 'add'.
-- * @Fixed target@: children are resolved recursively, and any missing child
--   yields 'Nothing'.
--   * If every child class is constant ('isConst'), the e-node is added and
--     the e-graph rebuilt.
--   * Otherwise the e-node is only looked up in '_eNodeToEClass'.
--
-- The e-node is built from the canonical child ids, so the lookup is not
-- missed just because the recursion returned an id that a merge has since
-- made stale. Returned ids are canonicalised.
classOfENode :: Monad m => CostFun -> Map ClassOrVar ClassOrVar -> Pattern -> EGraphST m (Maybe EClassId)
classOfENode _ subst (VarPat c) =
    traverse canonical (getInt <$> Map.lookup (Right (fromEnum c)) subst)
classOfENode costFun _ (Fixed (Const x)) = Just <$> add costFun (Const x)
classOfENode costFun subst (Fixed target) = do
    newChildren <- mapM (classOfENode costFun subst) (getElems target)
    case sequence newChildren of
      Nothing -> pure Nothing
      Just cs -> do
        -- 'add' canonizes its e-node before use, so keying the lookup with
        -- canonical children keeps both paths consistent.
        -- NOTE(review): assumes '_eNodeToEClass' is keyed by canonized
        -- e-nodes — confirm.
        cs' <- mapM canonical cs
        let newENode = replaceChildren cs' target
        areConsts <- mapM isConst cs'
        if and areConsts
          then do eid <- add costFun newENode
                  rebuild costFun
                  Just <$> canonical eid
          else gets (Map.lookup newENode . _eNodeToEClass) >>= traverse canonical
{-# INLINE classOfENode #-}
-- | Instantiates a pattern in the e-graph and returns its e-class id.
--
-- * A pattern variable resolves to the canonical class bound to it in the
--   substitution. It fails via 'Map.!' if the variable is unbound.
-- * A fixed node is built bottom-up: children are instantiated first, then
--   the node is inserted with 'add'.
reprPrat :: Monad m => CostFun -> Map ClassOrVar ClassOrVar -> Pattern -> EGraphST m EClassId
reprPrat _ subst (VarPat c) =
    canonical (getInt (subst Map.! Right (fromEnum c)))
reprPrat costFun subst (Fixed target) = do
    childIds <- traverse (reprPrat costFun subst) (getElems target)
    add costFun (replaceChildren childIds target)
{-# INLINE reprPrat #-}
-- | Accepts a match only if its root e-class is shallow enough.
--
-- The canonical root class must have a '_height' below 15. A root that is a
-- pattern variable ('Right') counts as height 0 and is always accepted.
isValidHeight :: Monad m => (Map ClassOrVar ClassOrVar, ClassOrVar) -> EGraphST m Bool
isValidHeight (_, root) = (< maxHeight) <$> either classHeight (const (pure 0)) root
  where
    maxHeight :: Int
    maxHeight = 15

    classHeight ec = do
      ec' <- canonical ec
      gets (_height . (IntMap.! ec') . _eClass)
{-# INLINE isValidHeight #-}
-- | Evaluates one rule condition against the current e-graph.
-- The condition receives the match's substitution; the match root is
-- ignored.
isValidConditions :: Monad m => Condition -> (Map ClassOrVar ClassOrVar, ClassOrVar) -> EGraphST m Bool
isValidConditions cond ~(subst, _) = gets (cond subst)
{-# INLINE isValidConditions #-}
-- | Inserts an expression tree into the e-graph with a monadic
-- catamorphism ('cataM').
--
-- Every node is passed to 'add' once its children have been replaced by
-- their e-class ids. Returns the e-class id of the root.
fromTree :: Monad m => CostFun -> Fix SRTree -> EGraphST m EClassId
fromTree costFun tree = cataM sequenceA (add costFun) tree
{-# INLINE fromTree #-}
-- | Inserts several trees, processing the input list left to right.
--
-- NOTE: the returned ids are in REVERSE order of the input list (the id of
-- the last tree comes first). Callers that need positional correspondence
-- with the input must 'reverse' the result.
fromTrees :: Monad m => CostFun -> [Fix SRTree] -> EGraphST m [EClassId]
fromTrees costFun trees = reverse <$> traverse (fromTree costFun) trees
{-# INLINE fromTrees #-}
-- | Counts the parameters of the best expression ('getBestExpr') of e-class
-- @rt@. This is a pure query: the e-graph is only read.
countParamsEg :: EGraph -> EClassId -> Int
countParamsEg eg rt = countParams bestTree
  where
    bestTree = evalState (getBestExpr rt) eg
-- | Same as 'countParamsEg', but counts with 'countParamsUniq' on the best
-- expression of e-class @rt@.
countParamsUniqEg :: EGraph -> EClassId -> Int
countParamsUniqEg eg rt = countParamsUniq bestTree
  where
    bestTree = evalState (getBestExpr rt) eg
-- | Extracts the best expression rooted at an e-class.
--
-- It takes the class's '_best' e-node ('getBestENode') and recursively
-- replaces each child id with that child's best expression.
getBestExpr :: Monad m => EClassId -> EGraphST m (Fix SRTree)
getBestExpr eid = do
  bestNode <- getBestENode eid
  subtrees <- traverse getBestExpr (childrenOf bestNode)
  pure (Fix (replaceChildren subtrees bestNode))
{-# INLINE getBestExpr #-}
-- | Returns the '_best' e-node stored in the analysis data ('_info') of the
-- canonical e-class of @eid@.
getBestENode :: Monad m => EClassId -> EGraphST m ENode
getBestENode eid = do
  eid' <- canonical eid
  gets (_best . _info . (IntMap.! eid') . _eClass)
{-# INLINE getBestENode #-}
-- | Extracts one expression from an e-class. The result is not necessarily
-- the best one.
--
-- When the class contains terminal e-nodes ('Var', 'Const', 'Param'), the
-- first of them is used, so extraction stops at a leaf. Otherwise the first
-- e-node (in 'Set.toList' order) is expanded and its children are extracted
-- recursively. Preferring terminals presumably also avoids descending
-- through self-referential classes.
getExpressionFrom :: Monad m => EClassId -> EGraphST m (Fix SRTree)
getExpressionFrom eId' = do
    eId <- canonical eId'
    nodes <- gets (Set.toList . Set.map decodeEnode . _eNodes . (IntMap.! eId) . _eClass)
    -- prefer terminal e-nodes; fall back to every e-node of the class
    let cands = case filter isTerm nodes of
                  []    -> nodes
                  terms -> terms
    case cands of
      []       -> error "getExpressionFrom: e-class has no e-nodes"
      (node:_) -> Fix <$> expand node
  where
    isTerm :: SRTree val -> Bool
    isTerm (Var _)   = True
    isTerm (Const _) = True
    isTerm (Param _) = True
    isTerm _         = False

    -- replace the child ids of an e-node by extracted sub-expressions
    expand :: Monad m => ENode -> EGraphST m (SRTree (Fix SRTree))
    expand (Bin op l r) = Bin op <$> getExpressionFrom l <*> getExpressionFrom r
    expand (Uni f t)    = Uni f <$> getExpressionFrom t
    expand (Var ix)     = pure (Var ix)
    expand (Const x)    = pure (Const x)
    expand (Param ix)   = pure (Param ix)
{-# INLINE getExpressionFrom #-}
-- | Enumerates every expression tree represented by the e-class @eId'@.
--
-- Each e-node of the (canonical) class contributes the cartesian product of
-- the expressions of its children, so the result can grow exponentially with
-- the size of the e-graph.  Results are grouped by e-node, in the order the
-- nodes are listed by the class's e-node set.
--
-- An e-class that is already being expanded on the current path is not
-- re-entered.  In a cyclic e-graph (e.g. a class containing @x * 1@ next to
-- @x@) the set of expressions is infinite; such a cycle now contributes no
-- expressions instead of recursing forever.  On acyclic e-graphs the result
-- is the full enumeration.
getAllExpressionsFrom :: Monad m => EClassId -> EGraphST m [Fix SRTree]
getAllExpressionsFrom eId' = allFrom IntSet.empty eId'
  where
    -- Expands one e-class; @onPath@ holds the canonical ids of the classes
    -- currently being expanded (the ancestors of this class).
    allFrom :: Monad m => IntSet.IntSet -> EClassId -> EGraphST m [Fix SRTree]
    allFrom onPath cid = do
      eId <- canonical cid
      if eId `IntSet.member` onPath
        then pure []  -- cycle: stop instead of looping forever
        else do
          nodes <- gets (map decodeEnode . Set.toList . _eNodes . (IntMap.! eId) . _eClass)
          let onPath' = IntSet.insert eId onPath
          concat <$> mapM (expand onPath') nodes

    -- All trees rooted at a single e-node.
    expand :: Monad m => IntSet.IntSet -> ENode -> EGraphST m [Fix SRTree]
    expand onPath node = map Fix <$> case node of
      Bin op l r -> do
        ls <- allFrom onPath l
        rs <- allFrom onPath r
        pure [Bin op li ri | li <- ls, ri <- rs]
      Uni f t  -> map (Uni f) <$> allFrom onPath t
      Var ix   -> pure [Var ix]
      Const x  -> pure [Const x]
      Param ix -> pure [Param ix]
{-# INLINE getAllExpressionsFrom #-}
-- | Asks 'getNExpressionsFrom'' for @n@ expressions of the e-class @eId'@,
-- using a fixed maximum depth of 15 levels.
getNExpressionsFrom :: Monad m => Int -> EClassId -> EGraphST m [Fix SRTree]
getNExpressionsFrom n = getNExpressionsFrom' n defaultDepth
  where
    -- depth limit used by this convenience wrapper
    defaultDepth :: Int
    defaultDepth = 15
-- | @getNExpressionsFrom' n d eId'@ enumerates at most @n@ expression trees
-- represented by the e-class @eId'@, exploring at most @d@ levels.
--
-- The e-nodes of the canonical class are expanded in the order they are
-- listed by the class's e-node set. Expansion stops as soon as the budget
-- @n@ is used up. Each e-node is expanded with the budget that is still
-- /remaining/, so the total never exceeds @n@. A sub-expression that needs
-- more than the available depth yields no trees.
--
-- Returns @[]@ when @n <= 0@ or @d <= 0@.
--
-- NOTE(review): the depth is also decremented when moving on to the next
-- sibling e-node of the same class. So at most @d@ e-nodes per class are
-- tried, and later ones get a shallower budget. This bounds the search but
-- may be unintended; confirm before changing.
getNExpressionsFrom' :: Monad m => Int -> Int -> EClassId -> EGraphST m [Fix SRTree]
getNExpressionsFrom' n d eId'
  | n <= 0 || d <= 0 = pure []
  | otherwise = do
      eId <- canonical eId'
      nodes <- gets (map decodeEnode . Set.toList . _eNodes . (IntMap.! eId) . _eClass)
      concat <$> go n d nodes
  where
    -- go budget depth nodes: expands e-nodes left to right until the
    -- remaining budget is exhausted. Invariant: budget >= 1 when a node
    -- is expanded.
    go :: Monad m => Int -> Int -> [ENode] -> EGraphST m [[Fix SRTree]]
    go _ _ [] = pure []
    go _ 0 _  = pure []
    go budget depth (node : rest) = do
      trees <- map Fix <$> case node of
        Bin op l r -> do
          ls <- getNExpressionsFrom' budget (depth - 1) l
          rs <- getNExpressionsFrom' budget (depth - 1) r
          -- truncate to the remaining budget (previously the total one)
          pure $ take budget [Bin op li ri | li <- ls, ri <- rs]
        Uni f t  -> map (Uni f) <$> getNExpressionsFrom' budget (depth - 1) t
        Var ix   -> pure [Var ix]
        Const x  -> pure [Const x]
        Param ix -> pure [Param ix]
      let budget' = budget - length trees
      if budget' <= 0
        then pure [trees]
        else (trees :) <$> go budget' (depth - 1) rest
-- | Collects the canonical ids of every e-class reachable from @eId'@,
-- including the canonical class of @eId'@ itself, in ascending order.
--
-- The search is breadth-first. A class is expanded only when /all/ of its
-- e-nodes have children. A class with at least one leaf e-node (one for
-- which 'childrenOf' is empty) is reported, but its children are not
-- followed.
getAllChildEClasses :: Monad m => EClassId -> EGraphST m [EClassId]
getAllChildEClasses eId' = do
  eId <- canonical eId'
  IntSet.toList <$> go [eId] IntSet.empty
  where
    -- True when no e-node of the class is a leaf.
    hasNoTerminal :: [ENode] -> Bool
    hasNoTerminal = all (not . null . childrenOf)

    getNodes :: Monad m => EClassId -> EGraphST m [ENode]
    getNodes n = gets (map decodeEnode . Set.toList . _eNodes . (IntMap.! n) . _eClass)

    -- go frontier visited: the frontier holds canonical ids not yet in visited.
    go :: Monad m => [EClassId] -> IntSet.IntSet -> EGraphST m IntSet.IntSet
    go [] visited = pure visited
    go frontier visited = do
      let visited' = visited `IntSet.union` IntSet.fromList frontier
      children <- concatMap childrenOf . concat . filter hasNoTerminal <$> mapM getNodes frontier
      canonChildren <- mapM canonical children
      -- Exclude classes already seen, including those in the current
      -- frontier (e.g. self-loops), and collapse duplicates so each class
      -- is expanded only once.
      let next = IntSet.fromList canonChildren `IntSet.difference` visited'
      go (IntSet.toList next) visited'
{-# INLINE getAllChildEClasses #-}
-- | Lists the canonical e-class ids used by the best (extracted) expression
-- of @eId'@, following each class's '_best' e-node.
--
-- The traversal is pre-order:
--
-- * a class is listed first,
-- * then the canonical ids of its best node's children,
-- * then each child's own listing, recursively.
--
-- Duplicates are removed, keeping the first occurrence. A class whose best
-- e-node has no children is listed alone.
--
-- NOTE(review): there is no guard against a cycle of '_best' nodes; such a
-- cycle would make this loop forever.
getAllChildBestEClasses :: Monad m => EClassId -> EGraphST m [EClassId]
getAllChildBestEClasses eId' = do
  eId <- canonical eId'
  dedupe <$> go eId
  where
    go :: Monad m => EClassId -> EGraphST m [EClassId]
    go n = do
      node <- gets (_best . _info . (IntMap.! n) . _eClass)
      let hasTerminal = null (childrenOf node)
      eids <- mapM canonical (childrenOf node)
      if hasTerminal
        then pure [n]
        else do
          eids' <- mapM go eids
          pure ((n : eids) <> concat eids')

    -- Keeps the first occurrence of each id (same result as 'nub'), in
    -- O(n log n) instead of O(n^2).
    dedupe :: [EClassId] -> [EClassId]
    dedupe = walk IntSet.empty
      where
        walk _ [] = []
        walk seen (x : xs)
          | x `IntSet.member` seen = walk seen xs
          | otherwise              = x : walk (IntSet.insert x seen) xs
-- | Samples one expression tree from the e-class @eId'@.
--
-- At every class, one encoded e-node is picked by drawing an index with
-- 'randomR' from the generator in the inner @State StdGen@ layer. Its
-- children are then sampled recursively, left child before right child.
--
-- NOTE(review): picking assumes a non-empty e-class (it indexes with '!!').
-- There is also no depth bound, so on a cyclic e-graph sampling may take
-- arbitrarily long.
getRndExpressionFrom :: EClassId -> EGraphST (State StdGen) (Fix SRTree)
getRndExpressionFrom eId' = do
  root    <- canonical eId'
  encoded <- gets (Set.toList . _eNodes . (IntMap.! root) . _eClass)
  picked  <- lift (pickOne encoded)
  shape   <- case decodeEnode picked of
    Bin op l r -> do
      left  <- getRndExpressionFrom l
      right <- getRndExpressionFrom r
      pure (Bin op left right)
    Uni f t  -> Uni f <$> getRndExpressionFrom t
    Var ix   -> pure (Var ix)
    Const x  -> pure (Const x)
    Param ix -> pure (Param ix)
  pure (Fix shape)
  where
    -- Draws an index in [0, length xs - 1] and returns that element.
    pickOne :: [a] -> State StdGen a
    pickOne xs = (xs !!) <$> state (randomR (0, length xs - 1))
{-# INLINE getRndExpressionFrom #-}
-- | Compacts the e-graph's lookup tables after merges.
--
-- * Every key of '_eNodeToEClass' is re-canonized with 'canonize', and
--   every value is mapped to its canonical id with 'canonical'. When
--   several old keys canonize to the same e-node, 'Map.fromList' keeps the
--   last one.
-- * '_eClass' is restricted to ids that are their own canonical id.
--
-- All other fields ('_canonicalMap', '_eDB') are kept as they are. The new
-- state is then forced to WHNF.
--
-- NOTE(review): '_canonicalMap' still maps merged ids to their
-- representatives. Presumably this is needed so that stale ids held
-- elsewhere keep resolving through 'canonical'; confirm before pruning it.
cleanMaps :: Monad m => EGraphST m ()
cleanMaps = do
  oldNodeMap <- gets _eNodeToEClass
  nodeEntries <- forM (Map.toList oldNodeMap) $ \(node, cid) -> do
    node' <- canonize node
    cid'  <- canonical cid
    pure (node', cid')
  let newNodeMap = Map.fromList nodeEntries

  oldClassMap <- gets _eClass
  classEntries <- forM (IntMap.toList oldClassMap) $ \(cid, cls) -> do
    root <- canonical cid
    pure $ if cid == root then Just (cid, cls) else Nothing
  -- IntMap.toList is ascending and filtering preserves that order.
  let newClassMap = IntMap.fromDistinctAscList (catMaybes classEntries)

  -- Record update instead of the positional constructor: unaffected by the
  -- order of EGraph's fields.
  modify' $ \eg -> eg { _eNodeToEClass = newNodeMap, _eClass = newClassMap }
  forceState
{-# INLINE cleanMaps #-}
-- | Evaluates the current state to weak head normal form and returns @()@.
--
-- Run it after writing a new state so that the stored value is not left as
-- an unevaluated thunk. The state itself is not modified.
forceState :: Monad m => StateT s m ()
forceState = do
  current <- get
  current `seq` pure ()
{-# INLINE forceState #-}