module Futhark.Optimise.Fusion.GraphRep
(
EdgeT (..),
NodeT (..),
DepContext,
DepGraphAug,
DepGraph (..),
DepNode,
getName,
nodeFromLNode,
mergedContext,
mapAcross,
edgesBetween,
reachable,
applyAugs,
depsFromEdge,
contractEdge,
isRealNode,
isCons,
isDep,
isInf,
mkDepGraph,
mkDepGraphForFun,
pprg,
isWithAccNodeT,
isWithAccNodeId,
vFusionFeasability,
hFusionFeasability,
)
where
import Control.Monad.Reader
import Data.Bifunctor (bimap)
import Data.Foldable (foldlM)
import Data.Graph.Inductive.Dot qualified as G
import Data.Graph.Inductive.Graph qualified as G
import Data.Graph.Inductive.Query.DFS qualified as Q
import Data.Graph.Inductive.Tree qualified as G
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe (mapMaybe)
import Data.Set qualified as S
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.HORep.SOAC qualified as H
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS hiding (SOAC (..))
import Futhark.IR.SOACS qualified as Futhark
import Futhark.Util (nubOrd)
-- | Label attached to an edge of the dependency graph.  Every
-- constructor carries the name of the variable that gives rise to the
-- edge; use 'getName' to extract it.
data EdgeT
  = -- | The pattern of the source statement aliases this variable
    -- (generated by 'addConsAndAliases').
    Alias VName
  | -- | Dependency on a variable that was /not/ classified as a SOAC
    -- input (the "infusible" names in 'addDeps').
    InfDep VName
  | -- | Dependency on a variable classified as a SOAC input (the
    -- "fusible" names in 'addDeps').
    Dep VName
  | -- | The source statement consumes this variable (generated by
    -- 'addConsAndAliases').
    Cons VName
  | -- | Additional edge inserted by 'addExtraCons' on behalf of a
    -- consumed variable.
    Fake VName
  | -- | Edge from a 'ResNode' to the node producing that result
    -- (generated by 'addResEdges').
    Res VName
  deriving (Eq, Ord)
-- | Label attached to a node of the dependency graph.
data NodeT
  = -- | An ordinary statement.  All statements start out as this (see
    -- 'emptyGraph') and may later be refined by 'nodeToSoacNode'.
    StmNode (Stm SOACS)
  | -- | A statement recognised as a SOAC: array transforms (initially
    -- empty), the statement pattern, the SOAC itself, and the
    -- statement's auxiliary information.
    SoacNode H.ArrayTransforms (Pat Type) (H.SOAC SOACS) (StmAux (ExpDec SOACS))
  | -- | An array transformation.  The first 'VName' is the output,
    -- the last is the input array.
    TransNode VName H.ArrayTransform VName
  | -- | One node per variable free in the result of the body the
    -- graph was built from.
    ResNode VName
  | -- | One node per variable that alias analysis of the body reports
    -- as consumed (see 'emptyGraph').  Rendered as "Input" by 'show'.
    FreeNode VName
  | -- | A @Match@ statement.  The list is empty when created by
    -- 'nodeToSoacNode'; NOTE(review): presumably it later holds nodes
    -- fused into the statement - confirm against the fusion pass.
    MatchNode (Stm SOACS) [(NodeT, [EdgeT])]
  | -- | A @Loop@ statement, with the same kind of list as 'MatchNode'.
    DoNode (Stm SOACS) [(NodeT, [EdgeT])]
  deriving (Eq)
-- | Short rendering used as the edge label in graph dumps (see
-- 'pprg').  Only 'Dep' and 'InfDep' include the variable name.
instance Show EdgeT where
  show (Dep vName) = "Dep " <> prettyString vName
  show (InfDep vName) = "iDep " <> prettyString vName
  show (Cons _) = "Cons"
  show (Fake _) = "Fake"
  show (Res _) = "Res"
  show (Alias _) = "Alias"
-- | Short rendering used as the node label in graph dumps (see
-- 'pprg').  Statement-like nodes are shown by the names they bind.
instance Show NodeT where
  show (StmNode (Let pat _ _)) =
    L.intercalate ", " $ map prettyString $ patNames pat
  show (SoacNode _ pat _ _) = prettyString pat
  show (TransNode _ tr _) = prettyString tr
  -- The outer 'prettyString' is applied to an already rendered
  -- 'String'; it is kept so the output stays exactly as before.
  show (ResNode name) = prettyString $ "Res: " ++ prettyString name
  show (FreeNode name) = prettyString $ "Input: " ++ prettyString name
  show (MatchNode stm _) =
    "Match: " ++ L.intercalate ", " (map prettyString $ stmNames stm)
  show (DoNode stm _) =
    "Do: " ++ L.intercalate ", " (map prettyString $ stmNames stm)
-- | The variable name carried by an edge label, whatever its
-- constructor.
getName :: EdgeT -> VName
getName edgeT = case edgeT of
  Alias vn -> vn
  InfDep vn -> vn
  Dep vn -> vn
  Cons vn -> vn
  Fake vn -> vn
  Res vn -> vn
-- | 'False' for the bookkeeping nodes 'ResNode' and 'FreeNode',
-- 'True' for every other kind of node.
isRealNode :: NodeT -> Bool
isRealNode ResNode {} = False
isRealNode FreeNode {} = False
isRealNode _ = True
-- | Render the graph in Graphviz dot syntax.  Node and edge labels
-- are first converted to strings with their 'Show' instances.
pprg :: DepGraph -> String
pprg = G.showDot . G.fglToDotString . G.nemap show show . dgGraph
-- | A graph node identifier paired with its 'NodeT' label.
type DepNode = G.LNode NodeT

-- | A labelled edge: source node, target node and 'EdgeT' label.
type DepEdge = G.LEdge EdgeT

-- | The context of a node: incoming adjacency, node identifier,
-- 'NodeT' label and outgoing adjacency.
type DepContext = G.Context NodeT EdgeT
-- | The dependency graph together with the lookup tables needed while
-- building and transforming it.
data DepGraph = DepGraph
  { -- | The underlying labelled graph.
    dgGraph :: G.Gr NodeT EdgeT,
    -- | Maps a variable name to the node that outputs it.  Empty
    -- until 'makeMapping' has been applied.
    dgProducerMapping :: ProducerMapping,
    -- | Alias information for the body the graph was built from, as
    -- computed by alias analysis in 'emptyGraph'.
    dgAliasTable :: AliasTable
  }
-- | A monadic transformation ("augmentation") of a 'DepGraph'.
type DepGraphAug m = DepGraph -> m DepGraph

-- | Given a node label, lists the variables the node should get an
-- edge for, each paired with the label of that edge.  The variable is
-- resolved to a target node through the 'ProducerMapping' (see
-- 'genEdges').
type EdgeGenerator = NodeT -> [(VName, EdgeT)]

-- | Maps a variable name to the graph node that outputs it.
type ProducerMapping = M.Map VName G.Node
-- | (Re)compute 'dgProducerMapping' from the current nodes of the
-- graph: every name returned by 'getOutputs' for a node is mapped to
-- that node.
makeMapping :: (Monad m) => DepGraphAug m
makeMapping dg@(DepGraph {dgGraph = g}) =
  pure dg {dgProducerMapping = M.fromList $ concatMap gen_dep_list (G.labNodes g)}
  where
    gen_dep_list :: DepNode -> [(VName, G.Node)]
    gen_dep_list (i, node) = [(name, i) | name <- getOutputs node]
-- | Apply a list of augmentations to a graph, left to right, feeding
-- the result of each one into the next.
applyAugs :: (Monad m) => [DepGraphAug m] -> DepGraphAug m
applyAugs augs g = foldlM (flip ($)) g augs
-- | Insert the edges that the given 'EdgeGenerator' produces for the
-- given nodes.  Each generated edge goes from the node itself to the
-- producer of the named variable; variables without an entry in
-- 'dgProducerMapping' are silently skipped.
genEdges :: (Monad m) => [DepNode] -> EdgeGenerator -> DepGraphAug m
genEdges l_stms edge_fun dg =
  depGraphInsertEdges (concatMap (genEdge (dgProducerMapping dg)) l_stms) dg
  where
    -- List monad: the pattern bind on 'Just' drops failed lookups.
    genEdge :: M.Map VName G.Node -> DepNode -> [G.LEdge EdgeT]
    genEdge name_map (from, node) = do
      (dep, edgeT) <- edge_fun node
      Just to <- [M.lookup dep name_map]
      pure $ G.toLEdge (from, to) edgeT
-- | Insert the given labelled edges into the graph, leaving the
-- lookup tables untouched.
depGraphInsertEdges :: (Monad m) => [DepEdge] -> DepGraphAug m
depGraphInsertEdges edgs dg = pure $ dg {dgGraph = G.insEdges edgs $ dgGraph dg}
-- | Monadically transform the context of every node in the graph.
--
-- The nodes to visit are those present in the graph on entry.  Each is
-- taken out of the current graph, its context is passed to the
-- function, and the resulting context is merged back in.  A node that
-- is no longer present when its turn comes is skipped.
mapAcross :: (Monad m) => (DepContext -> m DepContext) -> DepGraphAug m
mapAcross f dg = do
  g' <- foldlM (flip helper) (dgGraph dg) (G.nodes (dgGraph dg))
  pure $ dg {dgGraph = g'}
  where
    helper n g' = case G.match n g' of
      (Just c, g_new) -> do
        c' <- f c
        pure $ c' G.& g_new
      (Nothing, _) -> pure g'
-- | The statement of a 'StmNode' as a singleton sequence; empty for
-- every other kind of node.
stmFromNode :: NodeT -> Stms SOACS
stmFromNode (StmNode x) = oneStm x
stmFromNode _ = mempty
-- | The node identifier of a labelled node.
nodeFromLNode :: DepNode -> G.Node
nodeFromLNode = fst
-- | The variable name carried by the label of an edge.
depsFromEdge :: DepEdge -> VName
depsFromEdge = getName . G.edgeLabel
-- | All labelled edges of the subgraph induced by the two given
-- nodes.  Note that this covers edges in both directions as well as
-- any self-loops on either node.
edgesBetween :: DepGraph -> G.Node -> G.Node -> [DepEdge]
edgesBetween dg n1 n2 = G.labEdges $ G.subgraph [n1, n2] $ dgGraph dg
-- | @reachable dg source target@ is 'True' if @target@ is among the
-- nodes that a depth-first search from @source@ reaches.
reachable :: DepGraph -> G.Node -> G.Node -> Bool
reachable dg source target = target `elem` Q.reachable source (dgGraph dg)
-- | Run an 'EdgeGenerator' over every node currently in the graph and
-- insert the resulting edges.
augWithFun :: (Monad m) => EdgeGenerator -> DepGraphAug m
augWithFun f dg = genEdges (G.labNodes (dgGraph dg)) f dg
-- | Add the data-dependency edges.  The inputs of each 'StmNode'
-- statement are split by their classification: names classified as
-- 'SOACInput' yield 'Dep' edges, all others yield 'InfDep' edges.
-- Nodes that are not 'StmNode's contribute nothing, because
-- 'stmFromNode' is empty for them.
addDeps :: (Monad m) => DepGraphAug m
addDeps = augWithFun toDep
  where
    toDep stmt =
      let (fusible, infusible) =
            bimap (map fst) (map fst)
              . L.partition ((== SOACInput) . snd)
              . S.toList
              $ foldMap stmInputs (stmFromNode stmt)
          mkDep vname = (vname, Dep vname)
          mkInfDep vname = (vname, InfDep vname)
       in map mkDep fusible <> map mkInfDep infusible
-- | Add 'Cons' and 'Alias' edges for every 'StmNode'.  Each statement
-- is alias-analysed on its own, starting from an empty alias table.
-- A 'Cons' edge is generated for every name the statement consumes,
-- and an 'Alias' edge for every name aliased by its pattern.
addConsAndAliases :: (Monad m) => DepGraphAug m
addConsAndAliases = augWithFun edges
  where
    edges (StmNode s) = consEdges s' <> aliasEdges s'
      where
        s' = Alias.analyseStm mempty s
    edges _ = mempty
    consEdges s = zip names (map Cons names)
      where
        names = namesToList $ consumedInStm s
    aliasEdges =
      map (\vname -> (vname, Alias vname))
        . namesToList
        . mconcat
        . patAliases
        . stmPat
-- | Add 'Fake' edges derived from the existing 'Cons' edges.
--
-- For a 'Cons' edge from node @from@ to node @to@ on variable @cname@,
-- look at the labelled predecessors of @to@ and of the producers of
-- every alias of @cname@ (according to 'dgAliasTable').  Each such
-- predecessor other than @from@ whose edge names @cname@ or one of
-- its aliases receives a @Fake cname@ edge from @from@.
--
-- NOTE(review): the apparent purpose is to order the consumer after
-- all other users of the consumed array - confirm with the fusion
-- pass.
addExtraCons :: (Monad m) => DepGraphAug m
addExtraCons dg =
  depGraphInsertEdges (concatMap makeEdge (G.labEdges g)) dg
  where
    g = dgGraph dg
    alias_table = dgAliasTable dg
    mapping = dgProducerMapping dg
    makeEdge (from, to, Cons cname) = do
      let aliases = namesToList $ M.findWithDefault mempty cname alias_table
          -- Producers of the aliases; aliases without a producer are dropped.
          to' = mapMaybe (`M.lookup` mapping) aliases
          p (tonode, toedge) =
            tonode /= from && getName toedge `elem` (cname : aliases)
      (to2, _) <- filter p $ concatMap (G.lpre g) to' <> G.lpre g to
      pure $ G.toLEdge (from, to2) (Fake cname)
    makeEdge _ = []
-- | Monadically transform the label of every node, leaving node
-- identifiers and adjacency untouched.
mapAcrossNodeTs :: (Monad m) => (NodeT -> m NodeT) -> DepGraphAug m
mapAcrossNodeTs f = mapAcross f'
  where
    f' (ins, n, nodeT, outs) = do
      nodeT' <- f nodeT
      pure (ins, n, nodeT', outs)
-- | Refine a 'StmNode' into a more specific kind of node where
-- possible:
--
-- * an @Op@ that 'H.fromExp' recognises becomes a 'SoacNode' with no
--   array transforms;
--
-- * a @Loop@ becomes a 'DoNode' and a @Match@ a 'MatchNode', both
--   with an empty list;
--
-- * any other expression binding exactly one name that
--   'H.transformFromExp' recognises becomes a 'TransNode'.
--
-- Everything else, including nodes that are not 'StmNode's, is
-- returned unchanged.
nodeToSoacNode :: (HasScope SOACS m, Monad m) => NodeT -> m NodeT
nodeToSoacNode n@(StmNode s@(Let pat aux op)) = case op of
  Op {} -> do
    maybeSoac <- H.fromExp op
    case maybeSoac of
      Right hsoac -> pure $ SoacNode mempty pat hsoac aux
      Left H.NotSOAC -> pure n
  Loop {} ->
    pure $ DoNode s []
  Match {} ->
    pure $ MatchNode s []
  e
    | [output] <- patNames pat,
      Just (ia, tr) <- H.transformFromExp (stmAuxCerts aux) e ->
        pure $ TransNode output tr ia
  _ -> pure n
nodeToSoacNode n = pure n
-- | The initial, edge-free graph for a body.  Nodes are numbered from
-- zero in this order: one 'StmNode' per statement, one 'ResNode' per
-- variable free in the body result, and one 'FreeNode' per variable
-- that alias analysis of the statements reports as consumed.  The
-- producer mapping starts out empty; the alias table is the one
-- computed by that same analysis.
emptyGraph :: Body SOACS -> DepGraph
emptyGraph body =
  DepGraph
    { dgGraph = G.mkGraph (labelNodes (stmnodes <> resnodes <> inputnodes)) [],
      dgProducerMapping = mempty,
      dgAliasTable = aliases
    }
  where
    labelNodes = zip [0 ..]
    stmnodes = map StmNode $ stmsToList $ bodyStms body
    resnodes = map ResNode $ namesToList $ freeIn $ bodyResult body
    inputnodes = map FreeNode $ namesToList consumed
    (_, (aliases, consumed)) = Alias.analyseStms mempty $ bodyStms body
-- | Edge generator for result nodes: a 'ResNode' requests a single
-- 'Res' edge for its variable; all other nodes request nothing.
getStmRes :: EdgeGenerator
getStmRes (ResNode name) = [(name, Res name)]
getStmRes _ = []
-- | Connect every 'ResNode' to the producer of its variable with a
-- 'Res' edge.
addResEdges :: (Monad m) => DepGraphAug m
addResEdges = augWithFun getStmRes
-- | Build the dependency graph for a body: start from 'emptyGraph' and
-- apply the standard augmentations in order.
mkDepGraph :: (HasScope SOACS m, Monad m) => Body SOACS -> m DepGraph
mkDepGraph body = applyAugs augs $ emptyGraph body
  where
    augs =
      [ makeMapping,
        addDeps,
        addConsAndAliases,
        addExtraCons,
        addResEdges,
        -- Must come last: the edge generators above only inspect
        -- 'StmNode' labels, which this step may replace.
        mapAcrossNodeTs nodeToSoacNode
      ]
-- | Build the dependency graph for the body of a function.  The scope
-- used consists of the function parameters and the names bound by the
-- statements of the body.
mkDepGraphForFun :: FunDef SOACS -> DepGraph
mkDepGraphForFun f = runReader (mkDepGraph (funDefBody f)) scope
  where
    scope = scopeOfFParams (funDefParams f) <> scopeOf (bodyStms (funDefBody f))
-- | Merge two node contexts into one carrying the given label.  The
-- result keeps the node identifier of the first context.  Its
-- incoming and outgoing adjacency lists are the deduplicated
-- concatenation of the corresponding lists of both contexts, minus
-- any entry that refers to either of the two merged nodes.
mergedContext :: (Ord b) => a -> G.Context a b -> G.Context a b -> G.Context a b
mergedContext mergedlabel (inp1, n1, _, out1) (inp2, n2, _, out2) =
  let -- Drop adjacency entries pointing at the nodes being merged.
      external n = snd n /= n1 && snd n /= n2
      new_inp = filter external (nubOrd (inp1 <> inp2))
      new_out = filter external (nubOrd (out1 <> out2))
   in (new_inp, n1, mergedlabel, new_out)
-- | Delete the given node and the node described by the 'DepContext'
-- from the graph, then insert the 'DepContext' into what remains.
--
-- The node id carried by the context therefore stays in the graph,
-- now with the label and adjacencies given by the context.
contractEdge :: (Monad m) => G.Node -> DepContext -> DepGraphAug m
contractEdge n2 ctx dg =
  pure (dg {dgGraph = ctx G.& pruned})
  where
    n1 = G.node' ctx
    pruned = G.delNodes [n1, n2] (dgGraph dg)
-- | How a free variable is used by the expression it occurs in.
--
-- The derived 'Ord' instance orders 'SOACInput' before 'Other'.
data Classification
  = -- | Assigned by 'expInputs' to the array operands of a SOAC and to
    -- the array underlying an array transformation.
    SOACInput
  | -- | Assigned to every other free variable.
    Other
  deriving (Eq, Ord, Show)
-- | A set of variable names, each paired with a 'Classification'.  The
-- same name may occur with more than one classification.
type Classifications = S.Set (VName, Classification)
-- | Classify every free variable of the argument as 'Other'.
freeClassifications :: (FreeIn a) => a -> Classifications
freeClassifications x =
  S.fromList [(v, Other) | v <- namesToList (freeIn x)]
-- | Classified free variables of a statement: everything free in the
-- pattern and the auxiliary information counts as 'Other', while the
-- expression is classified by 'expInputs'.
stmInputs :: Stm SOACS -> Classifications
stmInputs (Let pat aux e) = fromBinding <> fromExp
  where
    fromBinding = freeClassifications (pat, aux)
    fromExp = expInputs e
-- | Classified free variables of a body: the union of 'stmInputs' over
-- all of its statements, plus the free variables of the body result
-- (classified as 'Other').
bodyInputs :: Body SOACS -> Classifications
bodyInputs (Body _ stms res) =
  mconcat
    [ foldMap stmInputs stms,
      freeClassifications res
    ]
-- | Classified free variables of an expression.
--
-- * For a 'Match' or 'Loop', the nested bodies are classified
--   recursively with 'bodyInputs'; everything else that is free in the
--   expression is 'Other'.
--
-- * For a SOAC with array operands, those operands are 'SOACInput' and
--   the remaining free variables are 'Other'.  For 'Futhark.JVP' and
--   'Futhark.VJP', every free variable is 'Other'.
--
-- * For any other expression that 'H.transformFromExp' recognises, the
--   array it reports is 'SOACInput' and the remaining free variables
--   are 'Other'.
--
-- * Otherwise every free variable is 'Other'.
expInputs :: Exp SOACS -> Classifications
expInputs (Match cond cases defbody attr) =
  mconcat
    [ foldMap (bodyInputs . caseBody) cases,
      bodyInputs defbody,
      freeClassifications (cond, attr)
    ]
expInputs (Loop params form b1) =
  freeClassifications (params, form) <> bodyInputs b1
expInputs (Op soac) = case soac of
  Futhark.Screma w arrs form ->
    arrayInputs arrs <> freeClassifications (w, form)
  Futhark.Hist w arrs ops lam ->
    arrayInputs arrs <> freeClassifications (w, ops, lam)
  Futhark.Scatter w arrs spec lam ->
    arrayInputs arrs <> freeClassifications (w, spec, lam)
  Futhark.Stream w arrs nes lam ->
    arrayInputs arrs <> freeClassifications (w, nes, lam)
  Futhark.JVP {} -> freeClassifications soac
  Futhark.VJP {} -> freeClassifications soac
  where
    -- Tag each array operand of the SOAC as a SOAC input.
    arrayInputs arrs = S.fromList [(arr, SOACInput) | arr <- arrs]
expInputs e =
  case H.transformFromExp mempty e of
    Just (arr, _) ->
      -- The array is reported once as SOACInput, so it is removed from
      -- the names that get the Other classification.
      S.insert
        (arr, SOACInput)
        (freeClassifications (freeIn e `namesSubtract` oneName arr))
    Nothing -> freeClassifications e
-- | The names bound by the pattern of a statement.
stmNames :: Stm SOACS -> [VName]
stmNames stm = patNames (stmPat stm)
-- | The names associated with a node as its outputs: the pattern names
-- for statement-like and SOAC nodes, the single carried name for
-- 'TransNode' (its first field) and 'FreeNode', and nothing for
-- 'ResNode'.
getOutputs :: NodeT -> [VName]
getOutputs (StmNode stm) = stmNames stm
getOutputs (TransNode v _ _) = [v]
getOutputs (ResNode _) = []
getOutputs (FreeNode name) = [name]
getOutputs (MatchNode stm _) = stmNames stm
getOutputs (DoNode stm _) = stmNames stm
getOutputs (SoacNode _ pat _ _) = patNames pat
-- | True exactly for 'Dep' and 'Res' edges.
isDep :: EdgeT -> Bool
isDep edge = case edge of
  Dep {} -> True
  Res {} -> True
  _ -> False
-- | True exactly when the label of the edge is an 'InfDep' or a 'Fake'.
-- The two endpoints of the edge are not inspected.
isInf :: (G.Node, G.Node, EdgeT) -> Bool
isInf (_, _, InfDep {}) = True
isInf (_, _, Fake {}) = True
isInf _ = False
-- | True exactly for 'Cons' edges.
isCons :: EdgeT -> Bool
isCons edge = case edge of
  Cons {} -> True
  _ -> False
-- | True exactly when the node is a 'StmNode' whose statement binds a
-- 'WithAcc' expression.
isWithAccNodeT :: NodeT -> Bool
isWithAccNodeT node = case node of
  StmNode (Let _ _ WithAcc {}) -> True
  _ -> False
-- | Does the node with the given id carry a label for which
-- 'isWithAccNodeT' holds?
--
-- The label is looked up with the total 'G.lab' rather than the
-- partial 'G.context', so a node id that is not present in the graph
-- yields 'False' instead of a runtime error.  This also avoids
-- building the adjacency lists of the context, which were never used.
isWithAccNodeId :: G.Node -> DepGraph -> Bool
isWithAccNodeId node_id dg =
  maybe False isWithAccNodeT (G.lab (dgGraph dg) node_id)
-- | True when neither @reachable g a b@ nor @reachable g b a@ holds.
unreachableEitherDir :: DepGraph -> G.Node -> G.Node -> Bool
unreachableEitherDir g a b =
  not (reachable g a b) && not (reachable g b a)
-- | Feasibility check for vertical fusion of nodes @n1@ and @n2@.
--
-- Both of the following must hold:
--
-- 1. @n2@ is a with-acc node (see 'isWithAccNodeId'), or none of the
--    edges between @n1@ and @n2@ satisfies 'isInf'.
--
-- 2. No predecessor @p@ of @n1@, other than @n2@ itself, satisfies
--    @reachable dg n2 p@.
vFusionFeasability :: DepGraph -> G.Node -> G.Node -> Bool
vFusionFeasability dg n1 n2 = edgesAllow && not reachesOtherPred
  where
    edgesAllow =
      isWithAccNodeId n2 dg || not (any isInf (edgesBetween dg n1 n2))
    otherPreds = [p | p <- G.pre (dgGraph dg) n1, p /= n2]
    reachesOtherPred = any (reachable dg n2) otherPreds
-- | Feasibility check for horizontal fusion of two nodes: this is
-- exactly 'unreachableEitherDir'.
hFusionFeasability :: DepGraph -> G.Node -> G.Node -> Bool
hFusionFeasability dg n1 n2 = unreachableEitherDir dg n1 n2