{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TupleSections #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.EqSat.Queries
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :
--
-- Query functions for e-graphs
-- Heavily based on hegg (https://github.com/alt-romes/hegg by alt-romes)
--
-----------------------------------------------------------------------------

module Algorithm.EqSat.Queries where

import Algorithm.EqSat.Egraph
import qualified Data.IntMap as IntMap
import qualified Data.Map as Map
import qualified Data.HashSet as Set
import Control.Monad.State ( gets, modify' )
import Control.Monad ( filterM )
import Control.Lens ( over )
import Data.Maybe
import Data.Sequence ( Seq(..) )
import qualified Data.Sequence as FingerTree
import qualified Data.Foldable as Foldable
import Data.SRTree (childrenOf)

import Debug.Trace

-- | Returns the ids of all e-classes whose 'EClass' satisfies the predicate @p@.
--
-- The ids come back in ascending order (the key order of '_eClass').
--
-- This is a full linear scan over every e-class, so it is slow on large
-- e-graphs; making it fast would need a dedicated index (basically one
-- database per kind of query).
getEClassesThat :: Monad m => (EClass -> Bool) -> EGraphST m [EClassId]
getEClassesThat p = gets (IntMap.keys . IntMap.filter p . _eClass)

-- | Stores @f@ as the fitness of e-class @ecId@, replacing any previous value.
--
-- The id is used exactly as given (it is not passed through 'canonical'),
-- and the e-class is looked up with 'IntMap.!', so it must already exist in
-- '_eClass'. Only the e-class record is rewritten; nothing in '_eDB' is
-- touched by this function.
updateFitness :: Monad m => Double -> EClassId -> EGraphST m ()
updateFitness f ecId = do
   oldClass <- gets (\eg -> _eClass eg IntMap.! ecId)
   let newInfo  = (_info oldClass) { _fitness = Just f }
       newClass = oldClass { _info = newInfo }
   modify' (over eClass (IntMap.insert ecId newClass))

-- | Returns the ids of the root e-classes, in ascending id order.
--
-- An e-class counts as a root when its parent set is empty, or when it
-- appears among its own parents (a self-referencing class).
findRootClasses :: Monad m => EGraphST m [EClassId]
findRootClasses = gets (IntMap.keys . IntMap.filterWithKey isRoot . _eClass)
  where
    isRoot :: EClassId -> EClass -> Bool
    isRoot ecId cls = Prelude.null parents || (ecId `Set.member` Set.map fst parents)
      where
        parents = _parents cls

-- | Returns up to @n@ e-class ids that satisfy the predicate @p@, scanning a
-- range database from its right end towards its left end (the right end is
-- assumed to hold the best-ranked entries — confirm against the RangeTree
-- insertion code).
--
-- * @b@ selects the database: 'True' reads '_fitRangeDB', 'False' reads
--   '_dlRangeDB'.
-- * Every stored id is mapped through 'canonical' before it is inspected.
-- * E-classes whose '_fitness' is missing, infinite or NaN are skipped and
--   do not count towards @n@. This check uses '_fitness' in both modes.
--
-- Matches are consed onto the accumulator, so the last element of the
-- result is the first match found at the right end of the tree.
getTopECLassThat :: Monad m => Bool -> Int -> (EClass -> Bool) -> EGraphST m [EClassId]
getTopECLassThat b n p = do
  let f = if b then _fitRangeDB else _dlRangeDB
  gets (f . _eDB) >>= go n []
  where
    -- An absent, infinite or NaN fitness marks an e-class that cannot be
    -- ranked; a missing fitness is treated as invalid instead of crashing.
    invalidFitness :: EClass -> Bool
    invalidFitness = maybe True (\v -> isInfinite v || isNaN v) . _fitness . _info

    go :: Monad m => Int -> [EClassId] -> RangeTree Double -> EGraphST m [EClassId]
    go 0 bests _  = pure bests
    go m bests rt = case rt of
      Empty        -> pure bests
      t :|> (_, x) -> do
        ecId <- canonical x
        ec   <- gets ((IntMap.! ecId) . _eClass)
        -- '||' short-circuits, so @p@ is only evaluated for rankable classes.
        if invalidFitness ec || not (p ec)
          then go m bests t
          else go (m - 1) (ecId : bests) t

-- | Returns up to @n@ e-class ids that satisfy @p@ and whose canonical id is
-- a member of @ecs'@, scanning a range database from its right end towards
-- its left end.
--
-- * @b@ selects the database: 'True' reads '_fitRangeDB', 'False' reads
--   '_dlRangeDB'.
-- * Stored ids are mapped through 'canonical' before the membership test.
--   The ids in @ecs'@ are used as given (NOTE(review): callers presumably
--   pass canonical ids — confirm).
-- * E-classes whose '_fitness' is missing, infinite or NaN are skipped and
--   do not count towards @n@.
--
-- Matches are consed onto the accumulator, so the last element of the
-- result is the first match found at the right end of the tree.
getTopECLassIn :: Monad m => Bool -> Int -> (EClass -> Bool) -> [EClassId] -> EGraphST m [EClassId]
getTopECLassIn b n p ecs' = do
  let f = if b then _fitRangeDB else _dlRangeDB
  gets (f . _eDB) >>= go n []
  where
    -- Hash set for O(1)-ish membership checks while scanning.
    ecs = Set.fromList ecs'

    -- An absent, infinite or NaN fitness marks an e-class that cannot be
    -- ranked; a missing fitness is treated as invalid instead of crashing.
    invalidFitness :: EClass -> Bool
    invalidFitness = maybe True (\v -> isInfinite v || isNaN v) . _fitness . _info

    go :: Monad m => Int -> [EClassId] -> RangeTree Double -> EGraphST m [EClassId]
    go 0 bests _  = pure bests
    go m bests rt = case rt of
      Empty        -> pure bests
      t :|> (_, x) -> do
        ecId <- canonical x
        ec   <- gets ((IntMap.! ecId) . _eClass)
        if invalidFitness ec || not (ecId `Set.member` ecs && p ec)
          then go m bests t
          else go (m - 1) (ecId : bests) t

-- | Returns up to @n@ e-class ids that satisfy @p@ and whose canonical id is
-- NOT a member of @ecs'@, scanning a range database from its right end
-- towards its left end.
--
-- * @b@ selects the database: 'True' reads '_fitRangeDB', 'False' reads
--   '_dlRangeDB'.
-- * Stored ids are mapped through 'canonical' before the membership test.
--   The ids in @ecs'@ are used as given (NOTE(review): callers presumably
--   pass canonical ids — confirm).
-- * E-classes whose '_fitness' is missing, infinite or NaN are skipped and
--   do not count towards @n@.
--
-- Matches are consed onto the accumulator, so the last element of the
-- result is the first match found at the right end of the tree.
getTopECLassNotIn :: Monad m => Bool -> Int -> (EClass -> Bool) -> [EClassId] -> EGraphST m [EClassId]
getTopECLassNotIn b n p ecs' = do
  let f = if b then _fitRangeDB else _dlRangeDB
  gets (f . _eDB) >>= go n []
  where
    -- Ids that must be excluded from the result.
    ecs = Set.fromList ecs'

    -- An absent, infinite or NaN fitness marks an e-class that cannot be
    -- ranked; a missing fitness is treated as invalid instead of crashing.
    invalidFitness :: EClass -> Bool
    invalidFitness = maybe True (\v -> isInfinite v || isNaN v) . _fitness . _info

    go :: Monad m => Int -> [EClassId] -> RangeTree Double -> EGraphST m [EClassId]
    go 0 bests _  = pure bests
    go m bests rt = case rt of
      Empty        -> pure bests
      t :|> (_, x) -> do
        ecId <- canonical x
        ec   <- gets ((IntMap.! ecId) . _eClass)
        if invalidFitness ec || ecId `Set.member` ecs || not (p ec)
          then go m bests t
          else go (m - 1) (ecId : bests) t

-- | Returns the canonical id of every entry in '_fitRangeDB' whose e-class has
-- a fitness that is present, finite and not NaN.
--
-- The tree is walked from its right end and ids are consed onto the result,
-- so the result follows the tree's left-to-right order. If several entries
-- canonicalise to the same e-class, that id appears once per entry.
getAllEvaluatedEClasses :: Monad m => EGraphST m [EClassId]
getAllEvaluatedEClasses = gets (_fitRangeDB . _eDB) >>= go []
  where
    -- An absent, infinite or NaN fitness marks an e-class that is not a
    -- usable evaluation; a missing fitness is skipped instead of crashing.
    invalidFitness :: EClass -> Bool
    invalidFitness = maybe True (\v -> isInfinite v || isNaN v) . _fitness . _info

    go :: Monad m => [EClassId] -> RangeTree Double -> EGraphST m [EClassId]
    go bests rt = case rt of
      Empty        -> pure bests
      t :|> (_, x) -> do
        ecId <- canonical x
        ec   <- gets ((IntMap.! ecId) . _eClass)
        go (if invalidFitness ec then bests else ecId : bests) t

-- | Returns up to @n@ e-class ids from the range tree stored under key @sz@
-- (presumably an expression size) in a size-indexed database, scanning that
-- tree from its right end towards its left end.
--
-- * @b@ selects the database: 'True' reads '_sizeFitDB', 'False' reads
--   '_sizeDLDB'.
-- * Entries whose key is infinite or NaN are skipped and do not count
--   towards @n@.
-- * If there is no tree for @sz@, the result is empty.
--
-- NOTE(review): unlike the other top-@n@ queries in this module, the ids are
-- returned exactly as stored and are not passed through 'canonical' —
-- confirm that callers do not need canonical ids.
getTopEClassWithSize :: Monad m => Bool -> Int -> Int -> EGraphST m [EClassId]
getTopEClassWithSize b sz n =
    gets (maybe [] (pick n []) . IntMap.lookup sz . selectDB . _eDB)
  where
    selectDB = if b then _sizeFitDB else _sizeDLDB

    -- Pure scan: @k@ is the number of ids still wanted, @acc@ the ids kept so far.
    pick 0 acc _  = acc
    pick k acc rt = case rt of
      Empty -> acc
      rest :|> (score, ecId)
        | isInfinite score || isNaN score -> pick k acc rest
        | otherwise                       -> pick (k - 1) (ecId : acc) rest

-- | 'getTopECLassThat' ranked by the fitness range database.
getTopFitEClassThat :: Monad m => Int -> (EClass -> Bool) -> EGraphST m [EClassId]
getTopFitEClassThat n p = getTopECLassThat True n p
-- | 'getTopECLassThat' ranked by the DL range database.
getTopDLEClassThat :: Monad m => Int -> (EClass -> Bool) -> EGraphST m [EClassId]
getTopDLEClassThat n p = getTopECLassThat False n p
-- | 'getTopECLassIn' ranked by the fitness range database.
getTopFitEClassIn :: Monad m => Int -> (EClass -> Bool) -> [EClassId] -> EGraphST m [EClassId]
getTopFitEClassIn n p ecs = getTopECLassIn True n p ecs
-- | 'getTopECLassIn' ranked by the DL range database.
getTopDLEClassIn :: Monad m => Int -> (EClass -> Bool) -> [EClassId] -> EGraphST m [EClassId]
getTopDLEClassIn n p ecs = getTopECLassIn False n p ecs
-- | 'getTopECLassNotIn' ranked by the fitness range database.
getTopFitEClassNotIn :: Monad m => Int -> (EClass -> Bool) -> [EClassId] -> EGraphST m [EClassId]
getTopFitEClassNotIn n p ecs = getTopECLassNotIn True n p ecs
-- | 'getTopECLassNotIn' ranked by the DL range database.
--
-- The flag must be 'False' to select '_dlRangeDB'; passing 'True' would
-- silently rank by fitness instead, unlike the other @getTopDL*@ wrappers.
getTopDLEClassNotIn :: Monad m => Int -> (EClass -> Bool) -> [EClassId] -> EGraphST m [EClassId]
getTopDLEClassNotIn = getTopECLassNotIn False
-- | 'getTopEClassWithSize' using the size-indexed fitness database.
getTopFitEClassWithSize :: Monad m => Int -> Int -> EGraphST m [EClassId]
getTopFitEClassWithSize sz n = getTopEClassWithSize True sz n
-- | 'getTopEClassWithSize' using the size-indexed DL database.
getTopDLEClassWithSize :: Monad m => Int -> Int -> EGraphST m [EClassId]
getTopDLEClassWithSize sz n = getTopEClassWithSize False sz n

rebuildAllRanges :: Monad m => EGraphST m ()
rebuildAllRanges :: forall (m :: * -> *). Monad m => EGraphST m ()
rebuildAllRanges = do IntMap (RangeTree Double)
szF <- (EGraph -> IntMap (RangeTree Double))
-> StateT EGraph m (IntMap (RangeTree Double))
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets (EGraphDB -> IntMap (RangeTree Double)
_sizeFitDB(EGraphDB -> IntMap (RangeTree Double))
-> (EGraph -> EGraphDB) -> EGraph -> IntMap (RangeTree Double)
forall b c a. (b -> c) -> (a -> b) -> a -> c
.EGraph -> EGraphDB
_eDB) StateT EGraph m (IntMap (RangeTree Double))
-> (IntMap (RangeTree Double)
    -> StateT EGraph m (IntMap (RangeTree Double)))
-> StateT EGraph m (IntMap (RangeTree Double))
forall a b.
StateT EGraph m a -> (a -> StateT EGraph m b) -> StateT EGraph m b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= (RangeTree Double -> StateT EGraph m (RangeTree Double))
-> IntMap (RangeTree Double)
-> StateT EGraph m (IntMap (RangeTree Double))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> IntMap a -> f (IntMap b)
traverse RangeTree Double -> StateT EGraph m (RangeTree Double)
forall (m :: * -> *).
Monad m =>
RangeTree Double -> EGraphST m (RangeTree Double)
rebuildRange
                      IntMap (RangeTree Double)
dlF <- (EGraph -> IntMap (RangeTree Double))
-> StateT EGraph m (IntMap (RangeTree Double))
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets (EGraphDB -> IntMap (RangeTree Double)
_sizeDLDB(EGraphDB -> IntMap (RangeTree Double))
-> (EGraph -> EGraphDB) -> EGraph -> IntMap (RangeTree Double)
forall b c a. (b -> c) -> (a -> b) -> a -> c
.EGraph -> EGraphDB
_eDB) StateT EGraph m (IntMap (RangeTree Double))
-> (IntMap (RangeTree Double)
    -> StateT EGraph m (IntMap (RangeTree Double)))
-> StateT EGraph m (IntMap (RangeTree Double))
forall a b.
StateT EGraph m a -> (a -> StateT EGraph m b) -> StateT EGraph m b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= (RangeTree Double -> StateT EGraph m (RangeTree Double))
-> IntMap (RangeTree Double)
-> StateT EGraph m (IntMap (RangeTree Double))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> IntMap a -> f (IntMap b)
traverse RangeTree Double -> StateT EGraph m (RangeTree Double)
forall (m :: * -> *).
Monad m =>
RangeTree Double -> EGraphST m (RangeTree Double)
rebuildRange
                      RangeTree Double
fR  <- (EGraph -> RangeTree Double) -> StateT EGraph m (RangeTree Double)
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets (EGraphDB -> RangeTree Double
_fitRangeDB(EGraphDB -> RangeTree Double)
-> (EGraph -> EGraphDB) -> EGraph -> RangeTree Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
.EGraph -> EGraphDB
_eDB) StateT EGraph m (RangeTree Double)
-> (RangeTree Double -> StateT EGraph m (RangeTree Double))
-> StateT EGraph m (RangeTree Double)
forall a b.
StateT EGraph m a -> (a -> StateT EGraph m b) -> StateT EGraph m b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= RangeTree Double -> StateT EGraph m (RangeTree Double)
forall (m :: * -> *).
Monad m =>
RangeTree Double -> EGraphST m (RangeTree Double)
rebuildRange
                      RangeTree Double
dR  <- (EGraph -> RangeTree Double) -> StateT EGraph m (RangeTree Double)
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets (EGraphDB -> RangeTree Double
_dlRangeDB(EGraphDB -> RangeTree Double)
-> (EGraph -> EGraphDB) -> EGraph -> RangeTree Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
.EGraph -> EGraphDB
_eDB) StateT EGraph m (RangeTree Double)
-> (RangeTree Double -> StateT EGraph m (RangeTree Double))
-> StateT EGraph m (RangeTree Double)
forall a b.
StateT EGraph m a -> (a -> StateT EGraph m b) -> StateT EGraph m b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= RangeTree Double -> StateT EGraph m (RangeTree Double)
forall (m :: * -> *).
Monad m =>
RangeTree Double -> EGraphST m (RangeTree Double)
rebuildRange

                      (EGraph -> EGraph) -> EGraphST m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' ((EGraph -> EGraph) -> EGraphST m ())
-> (EGraph -> EGraph) -> EGraphST m ()
forall a b. (a -> b) -> a -> b
$ ASetter EGraph EGraph (RangeTree Double) (RangeTree Double)
-> (RangeTree Double -> RangeTree Double) -> EGraph -> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((RangeTree Double -> Identity (RangeTree Double))
    -> EGraphDB -> Identity EGraphDB)
-> ASetter EGraph EGraph (RangeTree Double) (RangeTree Double)
forall b c a. (b -> c) -> (a -> b) -> a -> c
.(RangeTree Double -> Identity (RangeTree Double))
-> EGraphDB -> Identity EGraphDB
Lens' EGraphDB (RangeTree Double)
fitRangeDB) (RangeTree Double -> RangeTree Double -> RangeTree Double
forall a b. a -> b -> a
const RangeTree Double
fR)
                              (EGraph -> EGraph) -> (EGraph -> EGraph) -> EGraph -> EGraph
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ASetter EGraph EGraph (RangeTree Double) (RangeTree Double)
-> (RangeTree Double -> RangeTree Double) -> EGraph -> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((RangeTree Double -> Identity (RangeTree Double))
    -> EGraphDB -> Identity EGraphDB)
-> ASetter EGraph EGraph (RangeTree Double) (RangeTree Double)
forall b c a. (b -> c) -> (a -> b) -> a -> c
.(RangeTree Double -> Identity (RangeTree Double))
-> EGraphDB -> Identity EGraphDB
Lens' EGraphDB (RangeTree Double)
dlRangeDB) (RangeTree Double -> RangeTree Double -> RangeTree Double
forall a b. a -> b -> a
const RangeTree Double
dR)
                              (EGraph -> EGraph) -> (EGraph -> EGraph) -> EGraph -> EGraph
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ASetter
  EGraph
  EGraph
  (IntMap (RangeTree Double))
  (IntMap (RangeTree Double))
-> (IntMap (RangeTree Double) -> IntMap (RangeTree Double))
-> EGraph
-> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((IntMap (RangeTree Double)
     -> Identity (IntMap (RangeTree Double)))
    -> EGraphDB -> Identity EGraphDB)
-> ASetter
     EGraph
     EGraph
     (IntMap (RangeTree Double))
     (IntMap (RangeTree Double))
forall b c a. (b -> c) -> (a -> b) -> a -> c
.(IntMap (RangeTree Double) -> Identity (IntMap (RangeTree Double)))
-> EGraphDB -> Identity EGraphDB
Lens' EGraphDB (IntMap (RangeTree Double))
sizeFitDB) (IntMap (RangeTree Double)
-> IntMap (RangeTree Double) -> IntMap (RangeTree Double)
forall a b. a -> b -> a
const IntMap (RangeTree Double)
szF)
                              (EGraph -> EGraph) -> (EGraph -> EGraph) -> EGraph -> EGraph
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ASetter
  EGraph
  EGraph
  (IntMap (RangeTree Double))
  (IntMap (RangeTree Double))
-> (IntMap (RangeTree Double) -> IntMap (RangeTree Double))
-> EGraph
-> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((IntMap (RangeTree Double)
     -> Identity (IntMap (RangeTree Double)))
    -> EGraphDB -> Identity EGraphDB)
-> ASetter
     EGraph
     EGraph
     (IntMap (RangeTree Double))
     (IntMap (RangeTree Double))
forall b c a. (b -> c) -> (a -> b) -> a -> c
.(IntMap (RangeTree Double) -> Identity (IntMap (RangeTree Double)))
-> EGraphDB -> Identity EGraphDB
Lens' EGraphDB (IntMap (RangeTree Double))
sizeDLDB) (IntMap (RangeTree Double)
-> IntMap (RangeTree Double) -> IntMap (RangeTree Double)
forall a b. a -> b -> a
const IntMap (RangeTree Double)
dlF)

-- | Rewrite every e-class id stored in a range tree to its canonical
-- representative (as returned by 'canonical'), leaving the 'Double' keys and
-- the order of entries untouched.
canonizeRange :: Monad m => RangeTree Double -> EGraphST m (RangeTree Double)
canonizeRange = mapM canonicalEntry
  where
    -- Canonicalize the id of a single (key, e-class id) entry.
    canonicalEntry (key, classId) = do
      classId' <- canonical classId
      pure (key, classId')

-- | Canonicalize every e-class id in a range tree (via 'canonizeRange') and
-- then drop entries whose id is repeated.
--
-- The sequence is scanned from its right end, so for each e-class id only the
-- rightmost entry survives. Surviving entries keep their original relative
-- order.
rebuildRange :: Monad m => RangeTree Double -> EGraphST m (RangeTree Double)
rebuildRange rt = keepRightmost <$> canonizeRange rt
  where
    -- Strict right-to-left fold that carries the set of ids already kept and
    -- the rebuilt sequence (built by prepending, which preserves order).
    keepRightmost :: RangeTree Double -> RangeTree Double
    keepRightmost = snd . Foldable.foldr' step (Set.empty, Empty)

    -- An entry is kept only if its id was not already met further right.
    step :: (Double, EClassId)
         -> (Set.HashSet EClassId, RangeTree Double)
         -> (Set.HashSet EClassId, RangeTree Double)
    step entry@(_, classId) acc@(kept, out)
      | classId `Set.member` kept = acc
      | otherwise                 = (Set.insert classId kept, entry :<| out)