{-# LANGUAGE Strict #-}
module Futhark.Optimise.Fusion (fuseSOACs) where
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.Graph.Inductive.Graph qualified as G
import Data.Graph.Inductive.Query.DFS qualified as Q
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.HORep.SOAC qualified as H
import Futhark.Construct
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS hiding (SOAC (..))
import Futhark.IR.SOACS qualified as Futhark
import Futhark.IR.SOACS.Simplify (simplifyLambda)
import Futhark.Optimise.Fusion.GraphRep
import Futhark.Optimise.Fusion.RulesWithAccs qualified as SF
import Futhark.Optimise.Fusion.TryFusion qualified as TF
import Futhark.Pass
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
-- | State threaded through the fusion pass by 'FusionM'.
data FusionEnv = FusionEnv
  { -- | Source of fresh names; read and written by the
    -- 'MonadFreshNames' instance for 'FusionM'.
    vNameSource :: VNameSource,
    -- | Number of fusions performed so far; incremented by
    -- 'fusedSomething'.
    fusionCount :: Int,
    -- | Whether scans may currently be fused; temporarily overridden by
    -- 'doFuseScans' and 'dontFuseScans'.
    fuseScans :: Bool
  }
-- | Initial fusion state: a blank name source, a zero fusion count, and
-- scan fusion switched on.
freshFusionEnv :: FusionEnv
freshFusionEnv =
  FusionEnv
    { fuseScans = True,
      fusionCount = 0,
      vNameSource = blankNameSource
    }
-- | The fusion monad: a 'Scope' of SOACS type information made available
-- through 'ReaderT', on top of a 'State' holding the 'FusionEnv'.  All
-- instances are obtained by newtype deriving from the underlying stack.
newtype FusionM a = FusionM (ReaderT (Scope SOACS) (State FusionEnv) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      HasScope SOACS,
      LocalScope SOACS,
      MonadState FusionEnv
    )
-- | Fresh names come from, and are written back to, the 'vNameSource'
-- field of the fusion state.
instance MonadFreshNames FusionM where
  getNameSource = vNameSource <$> get
  putNameSource src = modify $ \env -> env {vNameSource = src}
-- | Run a fusion computation under the given scope, starting from the
-- given environment.
--
-- The environment's own name source is replaced by the one from the
-- enclosing 'MonadFreshNames' monad.  The final name source is handed
-- back to that monad; the rest of the final 'FusionEnv' is discarded.
runFusionM :: (MonadFreshNames m) => Scope SOACS -> FusionEnv -> FusionM a -> m a
runFusionM scope fenv (FusionM action) =
  modifyNameSource $ \src ->
    let (result, final_env) =
          runState (runReaderT action scope) (fenv {vNameSource = src})
     in (result, vNameSource final_env)
-- | Run an action with scan fusion enabled, then restore the previous
-- value of 'fuseScans'.
doFuseScans :: FusionM a -> FusionM a
doFuseScans = withFuseScans True

-- | Run an action with scan fusion disabled, then restore the previous
-- value of 'fuseScans'.
dontFuseScans :: FusionM a -> FusionM a
dontFuseScans = withFuseScans False

-- | Run an action with 'fuseScans' temporarily set to the given value.
--
-- The value that was in effect before the call is written back once the
-- action has finished, so nested uses behave like a scoped setting.
-- Only 'fuseScans' is restored; every other state change the action
-- makes (fresh names, fusion count) is kept.
withFuseScans :: Bool -> FusionM a -> FusionM a
withFuseScans enabled m = do
  saved <- gets fuseScans
  modify $ \s -> s {fuseScans = enabled}
  r <- m
  modify $ \s -> s {fuseScans = saved}
  pure r
-- | Keep only the inputs for which 'H.isVarInput' yields 'Nothing',
-- i.e. the inputs that are not plain variables.
isNotVarInput :: [H.Input] -> [H.Input]
isNotVarInput inps = [inp | inp <- inps, isNothing (H.isVarInput inp)]
-- | Turn one graph node back into the statements it stands for.
--
-- * 'StmNode' yields its statement unchanged.
-- * 'SoacNode' first binds the SOAC (rebuilt with 'H.toSOAC', under the
--   node's 'StmAux') to fresh names.  It then binds each original pattern
--   name to the result of applying the node's array transforms to the
--   corresponding fresh name.
-- * 'TransNode' binds its output to the expression returned by
--   'H.transformToExp', under the certificates returned alongside it.
-- * 'ResNode' and 'FreeNode' produce no statements.
-- * 'DoNode' and 'MatchNode' emit the statements of the nodes they carry,
--   followed by their own statement.
finalizeNode :: (HasScope SOACS m, MonadFreshNames m) => NodeT -> m (Stms SOACS)
finalizeNode node = case node of
  StmNode stm ->
    pure (oneStm stm)
  SoacNode ots outputs soac aux -> runBuilder_ $ do
    let out_names = patNames outputs
    raw_names <- mapM newName out_names
    -- 'H.toSOAC' stays inside 'auxing' so anything it emits also
    -- receives the node's auxiliary information.
    auxing aux $ do
      soac' <- H.toSOAC soac
      letBindNames raw_names (Op soac')
    forM_ (zip out_names raw_names) $ \(out_name, raw_name) -> do
      transformed <- H.applyTransforms ots raw_name
      letBindNames [out_name] (BasicOp (SubExp (Var transformed)))
  ResNode _ ->
    pure mempty
  TransNode output tr ia -> do
    (cs, e) <- H.transformToExp tr ia
    runBuilder_ $ certifying cs $ letBindNames [output] e
  FreeNode _ ->
    pure mempty
  DoNode stm carried ->
    withCarried stm carried
  MatchNode stm carried ->
    withCarried stm carried
  where
    -- Finalise the carried nodes and place their statements before the
    -- node's own statement.
    withCarried stm carried = do
      carried_stms <- mapM (finalizeNode . fst) carried
      pure (mconcat carried_stms <> oneStm stm)
-- | Flatten the whole dependency graph back into a statement sequence.
--
-- Nodes are finalised with 'finalizeNode' in the reverse of the order
-- produced by 'Q.topsort'', and the resulting statement groups are
-- concatenated.  (The reversal presumably reflects edges pointing from a
-- node to what it depends on; confirm against GraphRep.)
linearizeGraph :: (HasScope SOACS m, MonadFreshNames m) => DepGraph -> m (Stms SOACS)
linearizeGraph dg = do
  let ordered = reverse (Q.topsort' (dgGraph dg))
  mconcat <$> traverse finalizeNode ordered
-- | Count one more successful fusion in 'fusionCount' and return the
-- given node wrapped in 'Just'.
fusedSomething :: NodeT -> FusionM (Maybe NodeT)
fusedSomething x = do
  modify (\env -> env {fusionCount = fusionCount env + 1})
  pure (Just x)
-- | Try to vertically fuse @node_1@ with @node_2@ in the dependency graph.
--
-- The graph is returned unchanged when either node is no longer present,
-- when 'vFusionFeasability' rejects the pair, or when 'vFuseContexts'
-- fails.  On success, the edge between the nodes is contracted via
-- 'contractEdge', with the fused node placed at @node_1@.
--
-- If any edge between the two nodes is a consumption ('isCons'), the
-- fused node goes through 'makeCopiesOfFusedExcept'.  Names found on
-- consumption edges in the outgoing adjacency of either original context
-- are excluded from copying.
--
-- Everything derived from the nodes' graph contexts is computed only
-- after the membership and feasibility checks.  Under this module's
-- @Strict@ pragma, @where@-bound values are forced before the guards are
-- tested.  Computing them there would call 'G.pre' on @node_1@ even when
-- it has already been removed, which fails instead of returning the
-- graph unchanged.
vTryFuseNodesInGraph :: G.Node -> G.Node -> DepGraphAug FusionM
vTryFuseNodesInGraph node_1 node_2 dg@DepGraph {dgGraph = g}
  | not (G.gelem node_1 g && G.gelem node_2 g) = pure dg
  | vFusionFeasability dg node_1 node_2 = do
      let ctx1 = G.context g node_1
          ctx2 = G.context g node_2
          edgs = map G.edgeLabel $ edgesBetween dg node_1 node_2
          fusedC = map getName $ filter isCons edgs
          -- Dependencies on edges between node_1 and its other
          -- predecessors (node_2 excluded).
          infusable_nodes =
            map depsFromEdge $
              concatMap (edgesBetween dg node_1) $
                filter (/= node_2) (G.pre g node_1)
      fres <- vFuseContexts edgs infusable_nodes ctx1 ctx2
      case fres of
        Just (inputs, _, nodeT, outputs) -> do
          nodeT' <-
            if null fusedC
              then pure nodeT
              else do
                let (_, _, _, deps_1) = ctx1
                    (_, _, _, deps_2) = ctx2
                    -- Names already on a consumption edge of either
                    -- node are not copied.
                    old_cons =
                      map (getName . fst) $
                        filter (isCons . fst) (deps_1 <> deps_2)
                makeCopiesOfFusedExcept old_cons nodeT
          contractEdge node_2 (inputs, node_1, nodeT', outputs) dg
        Nothing -> pure dg
  | otherwise = pure dg
-- | Try to horizontally fuse @node_1@ and @node_2@.
--
-- Fusion is attempted only if both nodes are still in the graph and
-- 'hFusionFeasability' accepts the pair.  When 'hFuseContexts' produces a
-- merged context, it replaces the pair via 'contractEdge'; in every other
-- case the graph is returned unchanged.
hTryFuseNodesInGraph :: G.Node -> G.Node -> DepGraphAug FusionM
hTryFuseNodesInGraph node_1 node_2 dg@DepGraph {dgGraph = g}
  | G.gelem node_1 g,
    G.gelem node_2 g,
    hFusionFeasability dg node_1 node_2 = do
      merged <- hFuseContexts (G.context g node_1) (G.context g node_2)
      maybe (pure dg) (\ctx -> contractEdge node_2 ctx dg) merged
  | otherwise = pure dg
-- | Horizontally fuse the nodes of two contexts with 'hFuseNodeT'.
-- On success, returns the two contexts merged around the fused node
-- via 'mergedContext'; otherwise returns 'Nothing'.
hFuseContexts :: DepContext -> DepContext -> FusionM (Maybe DepContext)
hFuseContexts c1@(_, _, nodeT1, _) c2@(_, _, nodeT2, _) =
  fmap (\nodeT -> mergedContext nodeT c1 c2) <$> hFuseNodeT nodeT1 nodeT2
-- | Vertically fuse the nodes of two contexts with 'vFuseNodeT'.
--
-- The first node is passed with the labels of its incoming edges (minus
-- those attached to the second node) and the labels of all its outgoing
-- edges.  The second node is passed with the labels of its outgoing
-- edges, minus those attached to the first node.  On success, returns the
-- contexts merged around the fused node via 'mergedContext'; otherwise
-- returns 'Nothing'.
vFuseContexts :: [EdgeT] -> [VName] -> DepContext -> DepContext -> FusionM (Maybe DepContext)
vFuseContexts edgs infusable c1 c2 = do
  let (ins_1, n1, nodeT1, outs_1) = c1
      (_, n2, nodeT2, outs_2) = c2
      -- Edge labels of an adjacency list, skipping edges to/from @n@.
      labelsAvoiding n adj = [lbl | (lbl, other) <- adj, other /= n]
  fres <-
    vFuseNodeT
      edgs
      infusable
      (nodeT1, labelsAvoiding n2 ins_1, map fst outs_1)
      (nodeT2, labelsAvoiding n1 outs_2)
  pure $ fmap (\nodeT -> mergedContext nodeT c1 c2) fres
-- | For a SOAC node, insert copies into its lambda for every name the
-- lambda consumes, except names in the given list and names whose type
-- is an accumulator.  Consumption comes from 'Alias.analyseLambda', and
-- types are looked up with the lambda's own scope added.  Any other kind
-- of node is returned unchanged.
makeCopiesOfFusedExcept ::
  (LocalScope SOACS m, MonadFreshNames m) =>
  [VName] ->
  NodeT ->
  m NodeT
makeCopiesOfFusedExcept noCopy (SoacNode ots pats soac aux) =
  localScope (scopeOf lam) $ do
    let consumed =
          namesToList $ consumedByLambda $ Alias.analyseLambda mempty lam
    fused_inner <- filterM (fmap (not . isAcc) . lookupType) consumed
    lam' <- makeCopiesInLambda (fused_inner L.\\ noCopy) lam
    pure $ SoacNode ots pats (H.setLambda lam' soac) aux
  where
    lam = H.lambda soac
makeCopiesOfFusedExcept _ nodeT = pure nodeT
-- | Copy each of the given names at the start of the lambda's body.
-- Copy statements from 'makeCopyStms' are inserted before the body, and
-- the body's uses of the original names are renamed to the copies with
-- 'substituteNames'.
makeCopiesInLambda ::
  (LocalScope SOACS m, MonadFreshNames m) =>
  [VName] ->
  Lambda SOACS ->
  m (Lambda SOACS)
makeCopiesInLambda toCopy lam = do
  (copies, nameMap) <- makeCopyStms toCopy
  let body' = insertStms copies $ substituteNames nameMap $ lambdaBody lam
  pure (lam {lambdaBody = body'})
-- | Build one copy statement per input variable.
--
-- Each variable @v@ first gets a fresh name whose base string is
-- @baseString v <> "_copy"@. Fresh names for all variables are
-- generated before any statement is built. Each fresh name is then
-- bound to @Replicate mempty (Var v)@, a replicate with an empty shape,
-- which is used here as the copy operation.
--
-- Returns the copy statements (in input order) together with the map
-- from each original name to its fresh copy name.
makeCopyStms ::
  (LocalScope SOACS m, MonadFreshNames m) =>
  [VName] ->
  m (Stms SOACS, M.Map VName VName)
makeCopyStms vs = do
  renamings <- forM vs $ \v -> do
    v' <- newVName $ baseString v <> "_copy"
    pure (v, v')
  copyStms <- mapM bindCopy renamings
  pure (stmsFromList copyStms, M.fromList renamings)
  where
    -- Bind the fresh name to an empty-shape replicate of the original.
    bindCopy (orig, fresh) =
      mkLetNames [fresh] . BasicOp . Replicate mempty $ Var orig
-- | Whether the given SOAC may be used as the producer in a vertical
-- fusion.
--
-- A 'H.Screma' whose form is recognised by 'Futhark.isScanomapSOAC' is
-- only accepted when the 'fuseScans' field of the fusion state is
-- 'True'. Any other SOAC, including non-scanomap 'H.Screma's, is always
-- accepted.
okToFuseProducer :: H.SOAC SOACS -> FusionM Bool
okToFuseProducer soac =
  case soac of
    H.Screma _ _ form
      | isJust (Futhark.isScanomapSOAC form) -> gets fuseScans
    _ -> pure True
vFuseNodeT :: [EdgeT] -> [VName] -> (NodeT, [EdgeT], [EdgeT]) -> (NodeT, [EdgeT]) -> FusionM (Maybe NodeT)
vFuseNodeT :: [EdgeT]
-> [VName]
-> (NodeT, [EdgeT], [EdgeT])
-> (NodeT, [EdgeT])
-> FusionM (Maybe NodeT)
vFuseNodeT [EdgeT]
_ [VName]
infusible (NodeT
s1, [EdgeT]
_, [EdgeT]
e1s) (MatchNode Stm SOACS
stm2 [(NodeT, [EdgeT])]
dfused, [EdgeT]
_)
| NodeT -> Bool
isRealNode NodeT
s1,
[VName] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [VName]
infusible =
Maybe NodeT -> FusionM (Maybe NodeT)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe NodeT -> FusionM (Maybe NodeT))
-> Maybe NodeT -> FusionM (Maybe NodeT)
forall a b. (a -> b) -> a -> b
$ NodeT -> Maybe NodeT
forall a. a -> Maybe a
Just (NodeT -> Maybe NodeT) -> NodeT -> Maybe NodeT
forall a b. (a -> b) -> a -> b
$ Stm SOACS -> [(NodeT, [EdgeT])] -> NodeT
MatchNode Stm SOACS
stm2 ([(NodeT, [EdgeT])] -> NodeT) -> [(NodeT, [EdgeT])] -> NodeT
forall a b. (a -> b) -> a -> b
$ (NodeT
s1, [EdgeT]
e1s) (NodeT, [EdgeT]) -> [(NodeT, [EdgeT])] -> [(NodeT, [EdgeT])]
forall a. a -> [a] -> [a]
: [(NodeT, [EdgeT])]
dfused
vFuseNodeT [EdgeT]
_ [VName]
infusible (TransNode VName
stm1_out ArrayTransform
tr VName
stm1_in, [EdgeT]
_, [EdgeT]
_) (SoacNode ArrayTransforms
ots2 Pat Type
pats2 SOAC SOACS
soac2 StmAux (ExpDec SOACS)
aux2, [EdgeT]
_)
| [VName] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [VName]
infusible = do
Type
stm1_in_t <- VName -> FusionM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
stm1_in
let onInput :: Input -> Input
onInput Input
inp
| Input -> VName
H.inputArray Input
inp VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
stm1_out =
ArrayTransforms -> VName -> Type -> Input
H.Input (ArrayTransform
tr ArrayTransform -> ArrayTransforms -> ArrayTransforms
H.<| Input -> ArrayTransforms
H.inputTransforms Input
inp) VName
stm1_in Type
stm1_in_t
| Bool
otherwise =
Input
inp
soac2' :: SOAC SOACS
soac2' = (Input -> Input) -> [Input] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map Input -> Input
onInput (SOAC SOACS -> [Input]
forall rep. SOAC rep -> [Input]
H.inputs SOAC SOACS
soac2) [Input] -> SOAC SOACS -> SOAC SOACS
forall rep. [Input] -> SOAC rep -> SOAC rep
`H.setInputs` SOAC SOACS
soac2
Maybe NodeT -> FusionM (Maybe NodeT)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe NodeT -> FusionM (Maybe NodeT))
-> Maybe NodeT -> FusionM (Maybe NodeT)
forall a b. (a -> b) -> a -> b
$ NodeT -> Maybe NodeT
forall a. a -> Maybe a
Just (NodeT -> Maybe NodeT) -> NodeT -> Maybe NodeT
forall a b. (a -> b) -> a -> b
$ ArrayTransforms
-> Pat Type -> SOAC SOACS -> StmAux (ExpDec SOACS) -> NodeT
SoacNode ArrayTransforms
ots2 Pat Type
pats2 SOAC SOACS
soac2' StmAux (ExpDec SOACS)
aux2
vFuseNodeT
[EdgeT]
_
[VName]
_
(SoacNode ArrayTransforms
ots1 Pat Type
pats1 SOAC SOACS
soac1 StmAux (ExpDec SOACS)
aux1, [EdgeT]
i1s, [EdgeT]
_e1s)
(SoacNode ArrayTransforms
ots2 Pat Type
pats2 SOAC SOACS
soac2 StmAux (ExpDec SOACS)
aux2, [EdgeT]
_e2s) = do
let ker :: FusedSOAC
ker =
TF.FusedSOAC
{ fsSOAC :: SOAC SOACS
TF.fsSOAC = SOAC SOACS
soac2,
fsOutputTransform :: ArrayTransforms
TF.fsOutputTransform = ArrayTransforms
ots2,
fsOutNames :: [VName]
TF.fsOutNames = Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
pats2
}
preserveEdge :: EdgeT -> Bool
preserveEdge InfDep {} = Bool
True
preserveEdge EdgeT
e = EdgeT -> Bool
isDep EdgeT
e
preserve :: Names
preserve = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ (EdgeT -> VName) -> [EdgeT] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map EdgeT -> VName
getName ([EdgeT] -> [VName]) -> [EdgeT] -> [VName]
forall a b. (a -> b) -> a -> b
$ (EdgeT -> Bool) -> [EdgeT] -> [EdgeT]
forall a. (a -> Bool) -> [a] -> [a]
filter EdgeT -> Bool
preserveEdge [EdgeT]
i1s
Bool
ok <- SOAC SOACS -> FusionM Bool
okToFuseProducer SOAC SOACS
soac1
Maybe FusedSOAC
r <-
if Bool
ok Bool -> Bool -> Bool
&& ArrayTransforms
ots1 ArrayTransforms -> ArrayTransforms -> Bool
forall a. Eq a => a -> a -> Bool
== ArrayTransforms
forall a. Monoid a => a
mempty
then Mode
-> Names
-> [VName]
-> SOAC SOACS
-> FusedSOAC
-> FusionM (Maybe FusedSOAC)
forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Mode
-> Names
-> [VName]
-> SOAC SOACS
-> FusedSOAC
-> m (Maybe FusedSOAC)
TF.attemptFusion Mode
TF.Vertical Names
preserve (Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
pats1) SOAC SOACS
soac1 FusedSOAC
ker
else Maybe FusedSOAC -> FusionM (Maybe FusedSOAC)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe FusedSOAC
forall a. Maybe a
Nothing
case Maybe FusedSOAC
r of
Just FusedSOAC
ker' -> do
let pats2' :: [PatElem Type]
pats2' =
(VName -> Type -> PatElem Type)
-> [VName] -> [Type] -> [PatElem Type]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith VName -> Type -> PatElem Type
forall dec. VName -> dec -> PatElem dec
PatElem (FusedSOAC -> [VName]
TF.fsOutNames FusedSOAC
ker') (SOAC SOACS -> [Type]
forall rep. SOAC rep -> [Type]
H.typeOf (FusedSOAC -> SOAC SOACS
TF.fsSOAC FusedSOAC
ker'))
NodeT -> FusionM (Maybe NodeT)
fusedSomething (NodeT -> FusionM (Maybe NodeT)) -> NodeT -> FusionM (Maybe NodeT)
forall a b. (a -> b) -> a -> b
$
ArrayTransforms
-> Pat Type -> SOAC SOACS -> StmAux (ExpDec SOACS) -> NodeT
SoacNode
(FusedSOAC -> ArrayTransforms
TF.fsOutputTransform FusedSOAC
ker')
([PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem Type]
pats2')
(FusedSOAC -> SOAC SOACS
TF.fsSOAC FusedSOAC
ker')
(StmAux ()
StmAux (ExpDec SOACS)
aux1 StmAux () -> StmAux () -> StmAux ()
forall a. Semigroup a => a -> a -> a
<> StmAux ()
StmAux (ExpDec SOACS)
aux2)
Maybe FusedSOAC
Nothing -> Maybe NodeT -> FusionM (Maybe NodeT)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe NodeT
forall a. Maybe a
Nothing
vFuseNodeT
[EdgeT]
_
[VName]
infusible
(SoacNode ArrayTransforms
ots1 Pat Type
pat1 (H.Screma SubExp
w [Input]
inps ScremaForm SOACS
form) StmAux (ExpDec SOACS)
aux1, [EdgeT]
_, [EdgeT]
_)
(TransNode VName
stm2_out (H.Index Certs
cs slice :: Slice SubExp
slice@(Slice (ds :: DimIndex SubExp
ds@(DimSlice SubExp
_ SubExp
w' SubExp
_) : [DimIndex SubExp]
ds_rest))) VName
_, [EdgeT]
_)
| [VName] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [VName]
infusible,
SubExp
w SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
/= SubExp
w',
ArrayTransforms
ots1 ArrayTransforms -> ArrayTransforms -> Bool
forall a. Eq a => a -> a -> Bool
== ArrayTransforms
forall a. Monoid a => a
mempty,
Just Lambda SOACS
_ <- ScremaForm SOACS -> Maybe (Lambda SOACS)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form,
[PatElem Type
pe] <- Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
pat1 = do
let out_t :: Type
out_t = PatElem Type -> Type
forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem Type
pe Type -> ShapeBase SubExp -> Type
forall newshape oldshape u.
ArrayShape newshape =>
TypeBase oldshape u -> newshape -> TypeBase newshape u
`setArrayShape` Slice SubExp -> ShapeBase SubExp
forall d. Slice d -> ShapeBase d
sliceShape Slice SubExp
slice
inps' :: [Input]
inps' = (Input -> Input) -> [Input] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map Input -> Input
sliceInput [Input]
inps
ots1' :: ArrayTransforms
ots1' = ArrayTransforms
ots1 ArrayTransforms -> ArrayTransform -> ArrayTransforms
H.|> Certs -> Slice SubExp -> ArrayTransform
H.Index Certs
cs ([DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice (SubExp -> DimIndex SubExp
sliceDim SubExp
w' DimIndex SubExp -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. a -> [a] -> [a]
: [DimIndex SubExp]
ds_rest))
NodeT -> FusionM (Maybe NodeT)
fusedSomething (NodeT -> FusionM (Maybe NodeT)) -> NodeT -> FusionM (Maybe NodeT)
forall a b. (a -> b) -> a -> b
$
ArrayTransforms
-> Pat Type -> SOAC SOACS -> StmAux (ExpDec SOACS) -> NodeT
SoacNode
ArrayTransforms
ots1'
([PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat [VName -> Type -> PatElem Type
forall dec. VName -> dec -> PatElem dec
PatElem VName
stm2_out Type
out_t])
(SubExp -> [Input] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [Input] -> ScremaForm rep -> SOAC rep
H.Screma SubExp
w' [Input]
inps' ScremaForm SOACS
form)
StmAux (ExpDec SOACS)
aux1
where
sliceInput :: Input -> Input
sliceInput Input
inp =
ArrayTransform -> Input -> Input
H.addTransform
(Certs -> Slice SubExp -> ArrayTransform
H.Index Certs
cs (Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice (Input -> Type
H.inputType Input
inp) [DimIndex SubExp
ds]))
Input
inp
vFuseNodeT
[EdgeT]
_edges
[VName]
_infusible
(SoacNode ArrayTransforms
ots1 Pat Type
pat1 soac :: SOAC SOACS
soac@(H.Screma SubExp
_w [Input]
_form ScremaForm SOACS
_s_inps) StmAux (ExpDec SOACS)
aux1, [EdgeT]
is1, [EdgeT]
_os1)
(StmNode (Let Pat (LetDec SOACS)
pat2 StmAux (ExpDec SOACS)
aux2 (WithAcc [WithAccInput SOACS]
w_inps Lambda SOACS
lam0)), [EdgeT]
_os2)
| ArrayTransforms
ots1 ArrayTransforms -> ArrayTransforms -> Bool
forall a. Eq a => a -> a -> Bool
== ArrayTransforms
forall a. Monoid a => a
mempty,
Names
wacc_cons_nms <- [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ (WithAccInput SOACS -> [VName]) -> [WithAccInput SOACS] -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (\(ShapeBase SubExp
_, [VName]
nms, Maybe (Lambda SOACS, [SubExp])
_) -> [VName]
nms) [WithAccInput SOACS]
w_inps,
[VName]
soac_prod_nms <- (PatElem Type -> VName) -> [PatElem Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map PatElem Type -> VName
forall dec. PatElem dec -> VName
patElemName ([PatElem Type] -> [VName]) -> [PatElem Type] -> [VName]
forall a b. (a -> b) -> a -> b
$ Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
pat1,
[VName]
soac_indep_nms <- (EdgeT -> VName) -> [EdgeT] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map EdgeT -> VName
getName [EdgeT]
is1,
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Names -> Bool
`notNameIn` Names
wacc_cons_nms) ([VName]
soac_indep_nms [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_prod_nms) =
do
Lambda SOACS
lam <- (Lambda SOACS, Bool) -> Lambda SOACS
forall a b. (a, b) -> a
fst ((Lambda SOACS, Bool) -> Lambda SOACS)
-> FusionM (Lambda SOACS, Bool) -> FusionM (Lambda SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Lambda SOACS -> FusionM (Lambda SOACS, Bool)
doFusionInLambda Lambda SOACS
lam0
Body SOACS
bdy' <-
Builder SOACS Result -> FusionM (Body SOACS)
forall rep (m :: * -> *) somerep.
(Buildable rep, MonadFreshNames m, HasScope somerep m,
SameScope somerep rep) =>
Builder rep Result -> m (Body rep)
runBodyBuilder (Builder SOACS Result -> FusionM (Body SOACS))
-> Builder SOACS Result -> FusionM (Body SOACS)
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Builder SOACS Result -> Builder SOACS Result
forall rep a (m :: * -> *) b.
(Scoped rep a, LocalScope rep m) =>
a -> m b -> m b
inScopeOf Lambda SOACS
lam (Builder SOACS Result -> Builder SOACS Result)
-> Builder SOACS Result -> Builder SOACS Result
forall a b. (a -> b) -> a -> b
$ do
Exp SOACS
soac' <- SOAC (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT
SOACS
(State VNameSource)
(Exp (Rep (BuilderT SOACS (State VNameSource))))
forall (m :: * -> *).
(MonadBuilder m, Op (Rep m) ~ SOAC (Rep m)) =>
SOAC (Rep m) -> m (Exp (Rep m))
H.toExp SOAC (Rep (BuilderT SOACS (State VNameSource)))
SOAC SOACS
soac
Stm (Rep (BuilderT SOACS (State VNameSource))) -> Builder SOACS ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS ())
-> Stm (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat1 StmAux (ExpDec SOACS)
aux1 Exp SOACS
soac'
Result
lam_res <- Body (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Body (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS Result)
-> Body (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS Result
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam
let pat1_res :: Result
pat1_res = (VName -> SubExpRes) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (Certs -> SubExp -> SubExpRes
SubExpRes ([VName] -> Certs
Certs []) (SubExp -> SubExpRes) -> (VName -> SubExp) -> VName -> SubExpRes
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName]
soac_prod_nms
Result -> Builder SOACS Result
forall a. a -> BuilderT SOACS (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> Builder SOACS Result) -> Result -> Builder SOACS Result
forall a b. (a -> b) -> a -> b
$ Result
lam_res Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
pat1_res
let lam_ret_tp :: [Type]
lam_ret_tp = Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
lam [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ (PatElem Type -> Type) -> [PatElem Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map PatElem Type -> Type
forall dec. Typed dec => PatElem dec -> Type
patElemType (Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
pat1)
pat :: Pat Type
pat = [PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat ([PatElem Type] -> Pat Type) -> [PatElem Type] -> Pat Type
forall a b. (a -> b) -> a -> b
$ Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
Pat (LetDec SOACS)
pat2 [PatElem Type] -> [PatElem Type] -> [PatElem Type]
forall a. [a] -> [a] -> [a]
++ Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
pat1
Lambda SOACS
lam' <- Lambda SOACS -> FusionM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda (Lambda SOACS -> FusionM (Lambda SOACS))
-> Lambda SOACS -> FusionM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ Lambda SOACS
lam {lambdaBody = bdy', lambdaReturnType = lam_ret_tp}
(Lambda SOACS
lam'', Bool
success) <- Lambda SOACS -> FusionM (Lambda SOACS, Bool)
doFusionInLambda Lambda SOACS
lam'
if Bool -> Bool
not Bool
success
then Maybe NodeT -> FusionM (Maybe NodeT)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe NodeT
forall a. Maybe a
Nothing
else do
NodeT -> FusionM (Maybe NodeT)
fusedSomething (NodeT -> FusionM (Maybe NodeT)) -> NodeT -> FusionM (Maybe NodeT)
forall a b. (a -> b) -> a -> b
$ Stm SOACS -> NodeT
StmNode (Stm SOACS -> NodeT) -> Stm SOACS -> NodeT
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux2 (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ [WithAccInput SOACS] -> Lambda SOACS -> Exp SOACS
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [WithAccInput SOACS]
w_inps Lambda SOACS
lam''
vFuseNodeT
[EdgeT]
edges
[VName]
_infusible
(StmNode (Let Pat (LetDec SOACS)
pat1 StmAux (ExpDec SOACS)
aux1 (WithAcc [WithAccInput SOACS]
w_inps Lambda SOACS
wlam0)), [EdgeT]
_is1, [EdgeT]
_os1)
(SoacNode ArrayTransforms
ots2 Pat Type
pat2 soac :: SOAC SOACS
soac@(H.Screma SubExp
_w [Input]
_form ScremaForm SOACS
_s_inps) StmAux (ExpDec SOACS)
aux2, [EdgeT]
_os2)
| ArrayTransforms
ots2 ArrayTransforms -> ArrayTransforms -> Bool
forall a. Eq a => a -> a -> Bool
== ArrayTransforms
forall a. Monoid a => a
mempty,
Int
n <- [Param Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
wlam0) Int -> Int -> Int
forall a. Integral a => a -> a -> a
`div` Int
2,
Names
pat1_acc_nms <- [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
take Int
n ([VName] -> [VName]) -> [VName] -> [VName]
forall a b. (a -> b) -> a -> b
$ (PatElem Type -> VName) -> [PatElem Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map PatElem Type -> VName
forall dec. PatElem dec -> VName
patElemName ([PatElem Type] -> [VName]) -> [PatElem Type] -> [VName]
forall a b. (a -> b) -> a -> b
$ Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
Pat (LetDec SOACS)
pat1,
(EdgeT -> Bool) -> [EdgeT] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all ((VName -> Names -> Bool
`notNameIn` Names
pat1_acc_nms) (VName -> Bool) -> (EdgeT -> VName) -> EdgeT -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. EdgeT -> VName
getName) [EdgeT]
edges = do
let empty_aux :: StmAux ()
empty_aux = Certs -> Attrs -> () -> StmAux ()
forall dec. Certs -> Attrs -> dec -> StmAux dec
StmAux Certs
forall a. Monoid a => a
mempty Attrs
forall a. Monoid a => a
mempty ()
forall a. Monoid a => a
mempty
Lambda SOACS
wlam <- (Lambda SOACS, Bool) -> Lambda SOACS
forall a b. (a, b) -> a
fst ((Lambda SOACS, Bool) -> Lambda SOACS)
-> FusionM (Lambda SOACS, Bool) -> FusionM (Lambda SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Lambda SOACS -> FusionM (Lambda SOACS, Bool)
doFusionInLambda Lambda SOACS
wlam0
Body SOACS
bdy' <-
Builder SOACS Result -> FusionM (Body SOACS)
forall rep (m :: * -> *) somerep.
(Buildable rep, MonadFreshNames m, HasScope somerep m,
SameScope somerep rep) =>
Builder rep Result -> m (Body rep)
runBodyBuilder (Builder SOACS Result -> FusionM (Body SOACS))
-> Builder SOACS Result -> FusionM (Body SOACS)
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Builder SOACS Result -> Builder SOACS Result
forall rep a (m :: * -> *) b.
(Scoped rep a, LocalScope rep m) =>
a -> m b -> m b
inScopeOf Lambda SOACS
wlam (Builder SOACS Result -> Builder SOACS Result)
-> Builder SOACS Result -> Builder SOACS Result
forall a b. (a -> b) -> a -> b
$ do
Result
wlam_res <- Body (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Body (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS Result)
-> Body (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS Result
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
wlam
let other_pr1 :: [(PatElem Type, SubExpRes)]
other_pr1 = Int -> [(PatElem Type, SubExpRes)] -> [(PatElem Type, SubExpRes)]
forall a. Int -> [a] -> [a]
drop Int
n ([(PatElem Type, SubExpRes)] -> [(PatElem Type, SubExpRes)])
-> [(PatElem Type, SubExpRes)] -> [(PatElem Type, SubExpRes)]
forall a b. (a -> b) -> a -> b
$ [PatElem Type] -> Result -> [(PatElem Type, SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
Pat (LetDec SOACS)
pat1) Result
wlam_res
[(PatElem Type, SubExpRes)]
-> ((PatElem Type, SubExpRes) -> Builder SOACS ())
-> Builder SOACS ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [(PatElem Type, SubExpRes)]
other_pr1 (((PatElem Type, SubExpRes) -> Builder SOACS ())
-> Builder SOACS ())
-> ((PatElem Type, SubExpRes) -> Builder SOACS ())
-> Builder SOACS ()
forall a b. (a -> b) -> a -> b
$ \(PatElem Type
pat_elm, SubExpRes
bdy_res) -> do
let (VName
nm, SubExp
se, Type
tp) = (PatElem Type -> VName
forall dec. PatElem dec -> VName
patElemName PatElem Type
pat_elm, SubExpRes -> SubExp
resSubExp SubExpRes
bdy_res, PatElem Type -> Type
forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem Type
pat_elm)
aux :: StmAux ()
aux = StmAux ()
empty_aux {stmAuxCerts = resCerts bdy_res}
Stm (Rep (BuilderT SOACS (State VNameSource))) -> Builder SOACS ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS ())
-> Stm (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
-> StmAux (ExpDec (Rep (BuilderT SOACS (State VNameSource))))
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> Stm (Rep (BuilderT SOACS (State VNameSource)))
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat [VName -> Type -> PatElem Type
forall dec. VName -> dec -> PatElem dec
PatElem VName
nm Type
tp]) StmAux ()
StmAux (ExpDec (Rep (BuilderT SOACS (State VNameSource))))
aux (Exp (Rep (BuilderT SOACS (State VNameSource)))
-> Stm (Rep (BuilderT SOACS (State VNameSource))))
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> Stm (Rep (BuilderT SOACS (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT SOACS (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT SOACS (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT SOACS (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
Exp SOACS
soac' <- SOAC (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT
SOACS
(State VNameSource)
(Exp (Rep (BuilderT SOACS (State VNameSource))))
forall (m :: * -> *).
(MonadBuilder m, Op (Rep m) ~ SOAC (Rep m)) =>
SOAC (Rep m) -> m (Exp (Rep m))
H.toExp SOAC (Rep (BuilderT SOACS (State VNameSource)))
SOAC SOACS
soac
Stm (Rep (BuilderT SOACS (State VNameSource))) -> Builder SOACS ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS ())
-> Stm (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat2 StmAux (ExpDec SOACS)
aux2 Exp SOACS
soac'
let pat2_res :: Result
pat2_res = (PatElem Type -> SubExpRes) -> [PatElem Type] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (Certs -> SubExp -> SubExpRes
SubExpRes ([VName] -> Certs
Certs []) (SubExp -> SubExpRes)
-> (PatElem Type -> SubExp) -> PatElem Type -> SubExpRes
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var (VName -> SubExp)
-> (PatElem Type -> VName) -> PatElem Type -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElem Type -> VName
forall dec. PatElem dec -> VName
patElemName) ([PatElem Type] -> Result) -> [PatElem Type] -> Result
forall a b. (a -> b) -> a -> b
$ Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
pat2
Result -> Builder SOACS Result
forall a. a -> BuilderT SOACS (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> Builder SOACS Result) -> Result -> Builder SOACS Result
forall a b. (a -> b) -> a -> b
$ Result
wlam_res Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
pat2_res
let lam_ret_tp :: [Type]
lam_ret_tp = Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
wlam [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ (PatElem Type -> Type) -> [PatElem Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map PatElem Type -> Type
forall dec. Typed dec => PatElem dec -> Type
patElemType (Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
pat2)
pat :: Pat Type
pat = [PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat ([PatElem Type] -> Pat Type) -> [PatElem Type] -> Pat Type
forall a b. (a -> b) -> a -> b
$ Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
Pat (LetDec SOACS)
pat1 [PatElem Type] -> [PatElem Type] -> [PatElem Type]
forall a. [a] -> [a] -> [a]
++ Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
pat2
Lambda SOACS
wlam' <- Lambda SOACS -> FusionM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda (Lambda SOACS -> FusionM (Lambda SOACS))
-> Lambda SOACS -> FusionM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ Lambda SOACS
wlam {lambdaBody = bdy', lambdaReturnType = lam_ret_tp}
(Lambda SOACS
wlam'', Bool
success) <- Lambda SOACS -> FusionM (Lambda SOACS, Bool)
doFusionInLambda Lambda SOACS
wlam'
if Bool -> Bool
not Bool
success
then Maybe NodeT -> FusionM (Maybe NodeT)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe NodeT
forall a. Maybe a
Nothing
else
NodeT -> FusionM (Maybe NodeT)
fusedSomething (NodeT -> FusionM (Maybe NodeT)) -> NodeT -> FusionM (Maybe NodeT)
forall a b. (a -> b) -> a -> b
$ Stm SOACS -> NodeT
StmNode (Stm SOACS -> NodeT) -> Stm SOACS -> NodeT
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux1 (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ [WithAccInput SOACS] -> Lambda SOACS -> Exp SOACS
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [WithAccInput SOACS]
w_inps Lambda SOACS
wlam''
vFuseNodeT
[EdgeT]
_edges
[VName]
infusible
(StmNode (Let Pat (LetDec SOACS)
pat1 StmAux (ExpDec SOACS)
aux1 (WithAcc [WithAccInput SOACS]
w1_inps Lambda SOACS
lam1)), [EdgeT]
is1, [EdgeT]
_os1)
(StmNode (Let Pat (LetDec SOACS)
pat2 StmAux (ExpDec SOACS)
aux2 (WithAcc [WithAccInput SOACS]
w2_inps Lambda SOACS
lam2)), [EdgeT]
_os2)
| Names
wacc2_cons_nms <- [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ (WithAccInput SOACS -> [VName]) -> [WithAccInput SOACS] -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (\(ShapeBase SubExp
_, [VName]
nms, Maybe (Lambda SOACS, [SubExp])
_) -> [VName]
nms) [WithAccInput SOACS]
w2_inps,
[VName]
wacc1_indep_nms <- (EdgeT -> VName) -> [EdgeT] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map EdgeT -> VName
getName [EdgeT]
is1,
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Names -> Bool
`notNameIn` Names
wacc2_cons_nms) [VName]
wacc1_indep_nms = do
Lambda SOACS
lam1' <- (Lambda SOACS, Bool) -> Lambda SOACS
forall a b. (a, b) -> a
fst ((Lambda SOACS, Bool) -> Lambda SOACS)
-> FusionM (Lambda SOACS, Bool) -> FusionM (Lambda SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Lambda SOACS -> FusionM (Lambda SOACS, Bool)
doFusionInLambda Lambda SOACS
lam1
Lambda SOACS
lam2' <- (Lambda SOACS, Bool) -> Lambda SOACS
forall a b. (a, b) -> a
fst ((Lambda SOACS, Bool) -> Lambda SOACS)
-> FusionM (Lambda SOACS, Bool) -> FusionM (Lambda SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Lambda SOACS -> FusionM (Lambda SOACS, Bool)
doFusionInLambda Lambda SOACS
lam2
let stm1 :: Stm SOACS
stm1 = Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec SOACS)
pat1 StmAux (ExpDec SOACS)
aux1 ([WithAccInput SOACS] -> Lambda SOACS -> Exp SOACS
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [WithAccInput SOACS]
w1_inps Lambda SOACS
lam1')
stm2 :: Stm SOACS
stm2 = Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec SOACS)
pat2 StmAux (ExpDec SOACS)
aux2 ([WithAccInput SOACS] -> Lambda SOACS -> Exp SOACS
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [WithAccInput SOACS]
w2_inps Lambda SOACS
lam2')
Maybe (Stm SOACS)
mstm <- [VName] -> Stm SOACS -> Stm SOACS -> FusionM (Maybe (Stm SOACS))
forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
[VName] -> Stm SOACS -> Stm SOACS -> m (Maybe (Stm SOACS))
SF.tryFuseWithAccs [VName]
infusible Stm SOACS
stm1 Stm SOACS
stm2
case Maybe (Stm SOACS)
mstm of
Just (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (WithAcc [WithAccInput SOACS]
w_inps Lambda SOACS
wlam)) -> do
(Lambda SOACS
wlam', Bool
success) <- Lambda SOACS -> FusionM (Lambda SOACS, Bool)
doFusionInLambda Lambda SOACS
wlam
let new_stm :: Stm SOACS
new_stm = Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux ([WithAccInput SOACS] -> Lambda SOACS -> Exp SOACS
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [WithAccInput SOACS]
w_inps Lambda SOACS
wlam')
if Bool
success then NodeT -> FusionM (Maybe NodeT)
fusedSomething (Stm SOACS -> NodeT
StmNode Stm SOACS
new_stm) else Maybe NodeT -> FusionM (Maybe NodeT)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe NodeT
forall a. Maybe a
Nothing
Maybe (Stm SOACS)
_ -> String -> FusionM (Maybe NodeT)
forall a. HasCallStack => String -> a
error String
"Illegal result of tryFuseWithAccs called from vFuseNodeT."
vFuseNodeT [EdgeT]
_ [VName]
_ (NodeT, [EdgeT], [EdgeT])
_ (NodeT, [EdgeT])
_ = Maybe NodeT -> FusionM (Maybe NodeT)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe NodeT
forall a. Maybe a
Nothing
-- | The result list of a lambda, i.e. the 'bodyResult' of its 'lambdaBody'.
resFromLambda :: Lambda rep -> Result
resFromLambda lam = bodyResult (lambdaBody lam)
-- | Input check used as a precondition for horizontal fusion ('hFuseNodeT').
--
-- Both input lists are narrowed with 'isNotVarInput'. From the second list
-- the inputs that also occur in the first are removed (list difference,
-- one occurrence per element). The check succeeds when the two narrowed
-- lists have no element in common.
--
-- NOTE(review): inputs are compared with the 'Eq' instance of 'H.Input'
-- (presumably array name and transforms together). Assuming 'isNotVarInput'
-- only filters, @is2 L.\\ is1@ already drops the inputs shared with @is1@.
-- The intersection can therefore only be non-empty when @is2@ repeats such
-- an input. Confirm this is the intended check.
hasNoDifferingInputs :: [H.Input] -> [H.Input] -> Bool
hasNoDifferingInputs is1 is2 = null (narrowed1 `L.intersect` narrowed2)
  where
    narrowed1 = isNotVarInput is1
    narrowed2 = isNotVarInput (is2 L.\\ is1)
-- | Attempt horizontal fusion of two SOAC nodes.
--
-- Both nodes must carry empty output transforms and pass
-- 'hasNoDifferingInputs' on their inputs. The second node's SOAC is the
-- kernel, and the first node's SOAC is fused into it with
-- 'TF.attemptFusion' in 'TF.Horizontal' mode. All output names of the first
-- node are preserved. On success, the fused SOAC is reported through
-- 'fusedSomething' as a new 'SoacNode' whose aux is @aux1 <> aux2@.
-- Any other pair of nodes, or a failed attempt, gives 'Nothing'.
hFuseNodeT :: NodeT -> NodeT -> FusionM (Maybe NodeT)
hFuseNodeT (SoacNode ots1 pats1 soac1 aux1) (SoacNode ots2 pats2 soac2 aux2)
  | all (== mempty) [ots1, ots2],
    hasNoDifferingInputs (H.inputs soac1) (H.inputs soac2) =
      maybe (pure Nothing) toFusedNode
        =<< TF.attemptFusion TF.Horizontal (namesFromList outs1) outs1 soac1 kernel
  where
    outs1 = patNames pats1
    -- The second SOAC acts as the kernel that the first one is fused into.
    kernel =
      TF.FusedSOAC
        { TF.fsSOAC = soac2,
          TF.fsOutputTransform = mempty,
          TF.fsOutNames = patNames pats2
        }
    toFusedNode fused =
      let fusedSoac = TF.fsSOAC fused
          pes = zipWith PatElem (TF.fsOutNames fused) (H.typeOf fusedSoac)
       in fusedSomething $ SoacNode mempty (Pat pes) fusedSoac (aux1 <> aux2)
hFuseNodeT _ _ = pure Nothing
-- | Prune the map outputs of a 'H.Screma' node that are not listed.
--
-- The leading scan and reduce outputs are always kept. Of the remaining
-- (map) outputs, only those whose pattern name occurs in @toKeep@ survive.
-- The lambda's body results and return types are trimmed to match. Every
-- other kind of node is returned unchanged.
removeOutputsExcept :: [VName] -> NodeT -> NodeT
removeOutputsExcept toKeep (SoacNode ots (Pat pes) soac@(H.Screma _ _ (ScremaForm lam scans reds)) aux) =
  SoacNode ots (Pat (fixedPes <> keptPes)) (H.setLambda lam' soac) aux
  where
    -- Scan and reduce results come first, both in the pattern and in the
    -- lambda's results; they are never removed.
    numFixed = Futhark.scanResults scans + Futhark.redResults reds
    (fixedPes, mapPes) = splitAt numFixed pes
    (fixedRes, mapRes) =
      splitAt numFixed $ zip (resFromLambda lam) (lambdaReturnType lam)
    kept =
      [ (pe, r)
      | (pe, r) <- zip mapPes mapRes,
        patElemName pe `elem` toKeep
      ]
    keptPes = map fst kept
    (ress, rets) = unzip $ fixedRes ++ map snd kept
    lam' =
      lam
        { lambdaReturnType = rets,
          lambdaBody = (lambdaBody lam) {bodyResult = ress}
        }
removeOutputsExcept _ node = node
-- | Name associated with an adjacency entry @(label, src)@ of node @dst@.
-- Rebuilds the labelled edge @(src, dst, label)@ and passes it to
-- 'depsFromEdge'.
vNameFromAdj :: G.Node -> (EdgeT, G.Node) -> VName
vNameFromAdj dst = depsFromEdge . (\(label, src) -> (src, dst, label))
-- | Trim the node of a graph context so that it only keeps the outputs
-- named by its incoming edges (see 'removeOutputsExcept').
-- The adjacency lists and the node id are passed through unchanged.
removeUnusedOutputsFromContext :: DepContext -> FusionM DepContext
removeUnusedOutputsFromContext (inc, nid, nt, out) =
  let used = vNameFromAdj nid <$> inc
      nt' = removeOutputsExcept used nt
   in pure (inc, nid, nt', out)
-- | Apply 'removeUnusedOutputsFromContext' to every context of the graph
-- via 'mapAcross'.
removeUnusedOutputs :: DepGraphAug FusionM
removeUnusedOutputs dg = mapAcross removeUnusedOutputsFromContext dg
-- | Try to vertically fuse one node with its predecessors.
--
-- If the node is no longer in the graph, the graph is returned unchanged.
-- Otherwise the special rule 'SF.ruleMFScat' is tried first; if it applies,
-- its result is returned. If not, 'vTryFuseNodesInGraph' is attempted
-- with every predecessor reached through a relevant edge. An 'InfDep' edge
-- is relevant only when 'isWithAccNodeId' holds for its source; any other
-- edge is relevant when 'isDep' holds.
tryFuseNodeInGraph :: DepNode -> DepGraphAug FusionM
tryFuseNodeInGraph node_to_fuse dg@DepGraph {dgGraph = g}
  | not (G.gelem nid g) = pure dg
  | otherwise = do
      special <- SF.ruleMFScat node_to_fuse dg
      case special of
        Just dg' -> pure dg'
        Nothing -> do
          -- Only query predecessors once the node is known to be present.
          let partners = [n | (n, e) <- G.lpre g nid, isCandidate n e]
          applyAugs (map (vTryFuseNodesInGraph nid) partners) dg
  where
    nid = nodeFromLNode node_to_fuse
    isCandidate n (InfDep _) = isWithAccNodeId n dg
    isCandidate _ e = isDep e
-- | Run 'tryFuseNodeInGraph' on each considered node, in the reverse of
-- the order returned by 'G.labNodes'.
-- 'ResNode's are skipped. 'StmNode's are considered only when
-- 'isWithAccNodeT' holds, and every other node is considered.
doVerticalFusion :: DepGraphAug FusionM
doVerticalFusion dg = applyAugs (reverse attempts) dg
  where
    attempts =
      [ tryFuseNodeInGraph lnode
      | lnode@(_, nt) <- G.labNodes (dgGraph dg),
        considered nt
      ]
    considered nt@StmNode {} = isWithAccNodeT nt
    considered ResNode {} = False
    considered _ = True
-- | Try horizontal fusion ('hTryFuseNodesInGraph') on every pair of SOAC
-- nodes @x < y@ where some input array of @y@ is also an input array of @x@.
-- The attempt for a pair is skipped when an earlier fusion has already
-- removed either node from the graph.
doHorizontalFusion :: DepGraphAug FusionM
doHorizontalFusion dg = applyAugs pairs dg
  where
    soacs = [(n, soac) | (n, SoacNode _ _ soac _) <- G.labNodes (dgGraph dg)]

    pairs :: [DepGraphAug FusionM]
    pairs =
      [ fuseIfPresent x y
      | (x, soac_x) <- soacs,
        (y, soac_y) <- soacs,
        x < y,
        sharesInput soac_x soac_y
      ]

    sharesInput soac_x soac_y =
      let arrs_x = map H.inputArray (H.inputs soac_x)
       in any ((`elem` arrs_x) . H.inputArray) (H.inputs soac_y)

    -- Earlier augmentations may have removed one of the two nodes.
    fuseIfPresent x y dg'
      | G.gelem x g' && G.gelem y g' = hTryFuseNodesInGraph x y dg'
      | otherwise = pure dg'
      where
        g' = dgGraph dg'
-- | Apply 'runInnerFusionOnContext' to every context of the graph via
-- 'mapAcross'.
doInnerFusion :: DepGraphAug FusionM
doInnerFusion dg = mapAcross runInnerFusionOnContext dg
-- | Apply a graph transformation repeatedly until one application leaves
-- 'fusionCount' unchanged. Returns the graph produced by that last
-- application.
keepTrying :: DepGraphAug FusionM -> DepGraphAug FusionM
keepTrying step = loop
  where
    loop graph = do
      before <- gets fusionCount
      graph' <- step graph
      after <- gets fusionCount
      if before == after
        then pure graph'
        else loop graph'
-- | The full fusion pipeline.
--
-- Vertical, horizontal and inner fusion are combined with 'applyAugs' and
-- repeated by 'keepTrying' until a round leaves 'fusionCount' unchanged.
-- 'removeUnusedOutputs' is then applied to the resulting graph.
doAllFusion :: DepGraphAug FusionM
doAllFusion = applyAugs [fuseToFixpoint, removeUnusedOutputs]
  where
    fuseToFixpoint =
      keepTrying $ applyAugs [doVerticalFusion, doHorizontalFusion, doInnerFusion]
runInnerFusionOnContext :: DepContext -> FusionM DepContext
runInnerFusionOnContext :: Context NodeT EdgeT -> FusionM (Context NodeT EdgeT)
runInnerFusionOnContext c :: Context NodeT EdgeT
c@(Adj EdgeT
incoming, Int
node, NodeT
nodeT, Adj EdgeT
outgoing) = case NodeT
nodeT of
DoNode (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Loop [(FParam SOACS, SubExp)]
params LoopForm
form Body SOACS
body)) [(NodeT, [EdgeT])]
to_fuse ->
FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT)
forall a. FusionM a -> FusionM a
doFuseScans (FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT))
-> (FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT))
-> FusionM (Context NodeT EdgeT)
-> FusionM (Context NodeT EdgeT)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Scope SOACS
-> FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT)
forall a. Scope SOACS -> FusionM a -> FusionM a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([FParam SOACS] -> Scope SOACS
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams (((FParam SOACS, SubExp) -> FParam SOACS)
-> [(FParam SOACS, SubExp)] -> [FParam SOACS]
forall a b. (a -> b) -> [a] -> [b]
map (FParam SOACS, SubExp) -> FParam SOACS
forall a b. (a, b) -> a
fst [(FParam SOACS, SubExp)]
params) Scope SOACS -> Scope SOACS -> Scope SOACS
forall a. Semigroup a => a -> a -> a
<> LoopForm -> Scope SOACS
forall rep. LoopForm -> Scope rep
scopeOfLoopForm LoopForm
form) (FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT))
-> FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT)
forall a b. (a -> b) -> a -> b
$ do
Body SOACS
b <- Body SOACS -> [(NodeT, [EdgeT])] -> FusionM (Body SOACS)
doFusionWithDelayed Body SOACS
body [(NodeT, [EdgeT])]
to_fuse
Context NodeT EdgeT -> FusionM (Context NodeT EdgeT)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Adj EdgeT
incoming, Int
node, Stm SOACS -> [(NodeT, [EdgeT])] -> NodeT
DoNode (Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux ([(FParam SOACS, SubExp)] -> LoopForm -> Body SOACS -> Exp SOACS
forall rep.
[(FParam rep, SubExp)] -> LoopForm -> Body rep -> Exp rep
Loop [(FParam SOACS, SubExp)]
params LoopForm
form Body SOACS
b)) [], Adj EdgeT
outgoing)
MatchNode (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Match [SubExp]
cond [Case (Body SOACS)]
cases Body SOACS
defbody MatchDec (BranchType SOACS)
dec)) [(NodeT, [EdgeT])]
to_fuse -> FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT)
forall a. FusionM a -> FusionM a
doFuseScans (FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT))
-> FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT)
forall a b. (a -> b) -> a -> b
$ do
[Case (Body SOACS)]
cases' <- (Case (Body SOACS) -> FusionM (Case (Body SOACS)))
-> [Case (Body SOACS)] -> FusionM [Case (Body SOACS)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ((Body SOACS -> FusionM (Body SOACS))
-> Case (Body SOACS) -> FusionM (Case (Body SOACS))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Case a -> f (Case b)
traverse ((Body SOACS -> FusionM (Body SOACS))
-> Case (Body SOACS) -> FusionM (Case (Body SOACS)))
-> (Body SOACS -> FusionM (Body SOACS))
-> Case (Body SOACS)
-> FusionM (Case (Body SOACS))
forall a b. (a -> b) -> a -> b
$ Body SOACS -> FusionM (Body SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Body rep -> m (Body rep)
renameBody (Body SOACS -> FusionM (Body SOACS))
-> (Body SOACS -> FusionM (Body SOACS))
-> Body SOACS
-> FusionM (Body SOACS)
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< (Body SOACS -> [(NodeT, [EdgeT])] -> FusionM (Body SOACS)
`doFusionWithDelayed` [(NodeT, [EdgeT])]
to_fuse)) [Case (Body SOACS)]
cases
Body SOACS
defbody' <- Body SOACS -> [(NodeT, [EdgeT])] -> FusionM (Body SOACS)
doFusionWithDelayed Body SOACS
defbody [(NodeT, [EdgeT])]
to_fuse
Context NodeT EdgeT -> FusionM (Context NodeT EdgeT)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Adj EdgeT
incoming, Int
node, Stm SOACS -> [(NodeT, [EdgeT])] -> NodeT
MatchNode (Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux ([SubExp]
-> [Case (Body SOACS)]
-> Body SOACS
-> MatchDec (BranchType SOACS)
-> Exp SOACS
forall rep.
[SubExp]
-> [Case (Body rep)]
-> Body rep
-> MatchDec (BranchType rep)
-> Exp rep
Match [SubExp]
cond [Case (Body SOACS)]
cases' Body SOACS
defbody' MatchDec (BranchType SOACS)
dec)) [], Adj EdgeT
outgoing)
StmNode (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op (Futhark.VJP [SubExp]
args [SubExp]
vec Lambda SOACS
lam))) -> FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT)
forall a. FusionM a -> FusionM a
doFuseScans (FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT))
-> FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT)
forall a b. (a -> b) -> a -> b
$ do
Lambda SOACS
lam' <- (Lambda SOACS, Bool) -> Lambda SOACS
forall a b. (a, b) -> a
fst ((Lambda SOACS, Bool) -> Lambda SOACS)
-> FusionM (Lambda SOACS, Bool) -> FusionM (Lambda SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Lambda SOACS -> FusionM (Lambda SOACS, Bool)
doFusionInLambda Lambda SOACS
lam
Context NodeT EdgeT -> FusionM (Context NodeT EdgeT)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Adj EdgeT
incoming, Int
node, Stm SOACS -> NodeT
StmNode (Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op ([SubExp] -> [SubExp] -> Lambda SOACS -> SOAC SOACS
forall rep. [SubExp] -> [SubExp] -> Lambda rep -> SOAC rep
Futhark.VJP [SubExp]
args [SubExp]
vec Lambda SOACS
lam'))), Adj EdgeT
outgoing)
StmNode (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op (Futhark.JVP [SubExp]
args [SubExp]
vec Lambda SOACS
lam))) -> FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT)
forall a. FusionM a -> FusionM a
doFuseScans (FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT))
-> FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT)
forall a b. (a -> b) -> a -> b
$ do
Lambda SOACS
lam' <- (Lambda SOACS, Bool) -> Lambda SOACS
forall a b. (a, b) -> a
fst ((Lambda SOACS, Bool) -> Lambda SOACS)
-> FusionM (Lambda SOACS, Bool) -> FusionM (Lambda SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Lambda SOACS -> FusionM (Lambda SOACS, Bool)
doFusionInLambda Lambda SOACS
lam
Context NodeT EdgeT -> FusionM (Context NodeT EdgeT)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Adj EdgeT
incoming, Int
node, Stm SOACS -> NodeT
StmNode (Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op ([SubExp] -> [SubExp] -> Lambda SOACS -> SOAC SOACS
forall rep. [SubExp] -> [SubExp] -> Lambda rep -> SOAC rep
Futhark.JVP [SubExp]
args [SubExp]
vec Lambda SOACS
lam'))), Adj EdgeT
outgoing)
StmNode (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (WithAcc [WithAccInput SOACS]
inputs Lambda SOACS
lam)) -> FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT)
forall a. FusionM a -> FusionM a
doFuseScans (FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT))
-> FusionM (Context NodeT EdgeT) -> FusionM (Context NodeT EdgeT)
forall a b. (a -> b) -> a -> b
$ do
Lambda SOACS
lam' <- (Lambda SOACS, Bool) -> Lambda SOACS
forall a b. (a, b) -> a
fst ((Lambda SOACS, Bool) -> Lambda SOACS)
-> FusionM (Lambda SOACS, Bool) -> FusionM (Lambda SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Lambda SOACS -> FusionM (Lambda SOACS, Bool)
doFusionInLambda Lambda SOACS
lam
Context NodeT EdgeT -> FusionM (Context NodeT EdgeT)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Adj EdgeT
incoming, Int
node, Stm SOACS -> NodeT
StmNode (Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux ([WithAccInput SOACS] -> Lambda SOACS -> Exp SOACS
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [WithAccInput SOACS]
inputs Lambda SOACS
lam')), Adj EdgeT
outgoing)
SoacNode ArrayTransforms
ots Pat Type
pat SOAC SOACS
soac StmAux (ExpDec SOACS)
aux -> do
let lam :: Lambda SOACS
lam = SOAC SOACS -> Lambda SOACS
forall rep. SOAC rep -> Lambda rep
H.lambda SOAC SOACS
soac
Lambda SOACS
lam' <- Lambda SOACS -> FusionM (Lambda SOACS) -> FusionM (Lambda SOACS)
forall rep a (m :: * -> *) b.
(Scoped rep a, LocalScope rep m) =>
a -> m b -> m b
inScopeOf Lambda SOACS
lam (FusionM (Lambda SOACS) -> FusionM (Lambda SOACS))
-> FusionM (Lambda SOACS) -> FusionM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ case SOAC SOACS
soac of
H.Stream {} ->
FusionM (Lambda SOACS) -> FusionM (Lambda SOACS)
forall a. FusionM a -> FusionM a
dontFuseScans (FusionM (Lambda SOACS) -> FusionM (Lambda SOACS))
-> FusionM (Lambda SOACS) -> FusionM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ (Lambda SOACS, Bool) -> Lambda SOACS
forall a b. (a, b) -> a
fst ((Lambda SOACS, Bool) -> Lambda SOACS)
-> FusionM (Lambda SOACS, Bool) -> FusionM (Lambda SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Lambda SOACS -> FusionM (Lambda SOACS, Bool)
doFusionInLambda Lambda SOACS
lam
SOAC SOACS
_ ->
FusionM (Lambda SOACS) -> FusionM (Lambda SOACS)
forall a. FusionM a -> FusionM a
doFuseScans (FusionM (Lambda SOACS) -> FusionM (Lambda SOACS))
-> FusionM (Lambda SOACS) -> FusionM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ (Lambda SOACS, Bool) -> Lambda SOACS
forall a b. (a, b) -> a
fst ((Lambda SOACS, Bool) -> Lambda SOACS)
-> FusionM (Lambda SOACS, Bool) -> FusionM (Lambda SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Lambda SOACS -> FusionM (Lambda SOACS, Bool)
doFusionInLambda Lambda SOACS
lam
let nodeT' :: NodeT
nodeT' = ArrayTransforms
-> Pat Type -> SOAC SOACS -> StmAux (ExpDec SOACS) -> NodeT
SoacNode ArrayTransforms
ots Pat Type
pat (Lambda SOACS -> SOAC SOACS -> SOAC SOACS
forall rep. Lambda rep -> SOAC rep -> SOAC rep
H.setLambda Lambda SOACS
lam' SOAC SOACS
soac) StmAux (ExpDec SOACS)
aux
Context NodeT EdgeT -> FusionM (Context NodeT EdgeT)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Adj EdgeT
incoming, Int
node, NodeT
nodeT', Adj EdgeT
outgoing)
NodeT
_ -> Context NodeT EdgeT -> FusionM (Context NodeT EdgeT)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Context NodeT EdgeT
c
where
doFusionWithDelayed :: Body SOACS -> [(NodeT, [EdgeT])] -> FusionM (Body SOACS)
doFusionWithDelayed :: Body SOACS -> [(NodeT, [EdgeT])] -> FusionM (Body SOACS)
doFusionWithDelayed (Body () Stms SOACS
stms Result
res) [(NodeT, [EdgeT])]
extraNodes = Stms SOACS -> FusionM (Body SOACS) -> FusionM (Body SOACS)
forall rep a (m :: * -> *) b.
(Scoped rep a, LocalScope rep m) =>
a -> m b -> m b
inScopeOf Stms SOACS
stms (FusionM (Body SOACS) -> FusionM (Body SOACS))
-> FusionM (Body SOACS) -> FusionM (Body SOACS)
forall a b. (a -> b) -> a -> b
$ do
[Stms SOACS]
stm_node <- ((NodeT, [EdgeT]) -> FusionM (Stms SOACS))
-> [(NodeT, [EdgeT])] -> FusionM [Stms SOACS]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (NodeT -> FusionM (Stms SOACS)
forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
NodeT -> m (Stms SOACS)
finalizeNode (NodeT -> FusionM (Stms SOACS))
-> ((NodeT, [EdgeT]) -> NodeT)
-> (NodeT, [EdgeT])
-> FusionM (Stms SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (NodeT, [EdgeT]) -> NodeT
forall a b. (a, b) -> a
fst) [(NodeT, [EdgeT])]
extraNodes
Stms SOACS
stms' <- Body SOACS -> FusionM (Stms SOACS)
fuseGraph (Stms SOACS -> Result -> Body SOACS
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody ([Stms SOACS] -> Stms SOACS
forall a. Monoid a => [a] -> a
mconcat [Stms SOACS]
stm_node Stms SOACS -> Stms SOACS -> Stms SOACS
forall a. Semigroup a => a -> a -> a
<> Stms SOACS
stms) Result
res)
Body SOACS -> FusionM (Body SOACS)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Body SOACS -> FusionM (Body SOACS))
-> Body SOACS -> FusionM (Body SOACS)
forall a b. (a -> b) -> a -> b
$ BodyDec SOACS -> Stms SOACS -> Result -> Body SOACS
forall rep. BodyDec rep -> Stms rep -> Result -> Body rep
Body () Stms SOACS
stms' Result
res
doFusionInLambda :: Lambda SOACS -> FusionM (Lambda SOACS, Bool)
doFusionInLambda :: Lambda SOACS -> FusionM (Lambda SOACS, Bool)
doFusionInLambda Lambda SOACS
lam = do
Lambda SOACS
lam' <- Lambda SOACS -> FusionM (Lambda SOACS)
forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Lambda SOACS -> m (Lambda SOACS)
simplifyLambda Lambda SOACS
lam
Int
prev_count <- (FusionEnv -> Int) -> FusionM Int
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets FusionEnv -> Int
fusionCount
Body SOACS
newbody <- Lambda SOACS -> FusionM (Body SOACS) -> FusionM (Body SOACS)
forall rep a (m :: * -> *) b.
(Scoped rep a, LocalScope rep m) =>
a -> m b -> m b
inScopeOf Lambda SOACS
lam' (FusionM (Body SOACS) -> FusionM (Body SOACS))
-> FusionM (Body SOACS) -> FusionM (Body SOACS)
forall a b. (a -> b) -> a -> b
$ Body SOACS -> FusionM (Body SOACS)
doFusionBody (Body SOACS -> FusionM (Body SOACS))
-> Body SOACS -> FusionM (Body SOACS)
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam'
Int
aft_count <- (FusionEnv -> Int) -> FusionM Int
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets FusionEnv -> Int
fusionCount
Lambda SOACS
lam'' <-
(if Int
prev_count Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= Int
aft_count then Lambda SOACS -> FusionM (Lambda SOACS)
forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Lambda SOACS -> m (Lambda SOACS)
simplifyLambda else Lambda SOACS -> FusionM (Lambda SOACS)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure)
Lambda SOACS
lam' {lambdaBody = newbody}
(Lambda SOACS, Bool) -> FusionM (Lambda SOACS, Bool)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Lambda SOACS
lam'', Int
prev_count Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= Int
aft_count)
where
doFusionBody :: Body SOACS -> FusionM (Body SOACS)
doFusionBody :: Body SOACS -> FusionM (Body SOACS)
doFusionBody Body SOACS
body = do
Stms SOACS
stms' <- Body SOACS -> FusionM (Stms SOACS)
fuseGraph Body SOACS
body
Body SOACS -> FusionM (Body SOACS)
forall a. a -> FusionM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Body SOACS -> FusionM (Body SOACS))
-> Body SOACS -> FusionM (Body SOACS)
forall a b. (a -> b) -> a -> b
$ Body SOACS
body {bodyStms = stms'}
-- | Fuse the statements of a body.
--
-- Builds the body's dependency graph with 'mkDepGraph', runs
-- 'doAllFusion' over that graph, and turns the fused graph back into a
-- statement sequence with 'linearizeGraph'.  The body's own statements
-- are in scope throughout.  The body's result is not returned; callers
-- re-attach it themselves.
fuseGraph :: Body SOACS -> FusionM (Stms SOACS)
fuseGraph body =
  inScopeOf (bodyStms body) $
    mkDepGraph body >>= doAllFusion >>= linearizeGraph
-- | Run fusion over the top-level constant statements @stms@.
--
-- The statements are wrapped in a synthetic body whose result is the
-- variables @outputs@, and that body is fused with 'fuseGraph'.  The
-- run starts from 'freshFusionEnv', with the constants' own scope.
-- Only the fused statements are returned.
fuseConsts :: [VName] -> Stms SOACS -> PassM (Stms SOACS)
fuseConsts outputs stms = runFusionM constScope freshFusionEnv fuseConstBody
  where
    constScope :: Scope SOACS
    constScope = scopeOf stms

    fuseConstBody :: FusionM (Stms SOACS)
    fuseConstBody = fuseGraph $ mkBody stms $ varsRes outputs
-- | Run fusion over the body of a single function.
--
-- Fusion runs in the scope of @fun@ combined with that of the top-level
-- constants @consts@, starting from 'freshFusionEnv'.  Only the
-- statements of the function body are replaced.  The body's result and
-- every other field of @fun@ are kept unchanged.
fuseFun :: Stms SOACS -> FunDef SOACS -> PassM (FunDef SOACS)
fuseFun consts fun =
  withStms <$> runFusionM funScope freshFusionEnv (fuseGraph body)
  where
    body :: Body SOACS
    body = funDefBody fun

    funScope :: Scope SOACS
    funScope = scopeOf fun <> scopeOf consts

    -- Rebuild the function with the fused statements in its body.
    withStms :: Stms SOACS -> FunDef SOACS
    withStms stms' = fun {funDefBody = body {bodyStms = stms'}}
-- | The fusion pass.
--
-- Applies 'fuseConsts' to the program's constants and 'fuseFun' to
-- each of its functions.
{-# NOINLINE fuseSOACs #-}
fuseSOACs :: Pass SOACS SOACS
fuseSOACs =
  Pass
    { passName = "Fuse SOACs",
      passDescription = "Perform higher-order optimisation, i.e., fusion.",
      passFunction = fuseProg
    }
  where
    fuseProg :: Prog SOACS -> PassM (Prog SOACS)
    fuseProg prog =
      intraproceduralTransformationWithConsts
        (fuseConsts (usedByFuns prog))
        fuseFun
        prog

    -- Every name occurring free in any function of the program.  These
    -- are handed to 'fuseConsts' as the result names of the constants'
    -- body.
    usedByFuns :: Prog SOACS -> [VName]
    usedByFuns = namesToList . freeIn . progFuns