{-# LANGUAGE Strict #-}
module Futhark.Optimise.Fusion.RulesWithAccs
( ruleMFScat,
tryFuseWithAccs,
)
where
import Control.Monad
import Data.Graph.Inductive.Graph qualified as G
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.HORep.SOAC qualified as H
import Futhark.Construct
import Futhark.IR.SOACS hiding (SOAC (..))
import Futhark.IR.SOACS qualified as F
import Futhark.Optimise.Fusion.GraphRep
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
-- | The 64-bit integer constant @0@ as a 'SubExp'.
se0 :: SubExp
se0 = intConst Int64 0
-- | The 64-bit integer constant @1@ as a 'SubExp'.
se1 :: SubExp
se1 = intConst Int64 1
-- | A scatter input produced by an @Iota@: the input array's name paired
-- with the scatter-lambda parameter bound to it, followed by the three
-- 'SubExp' operands of the producing @Iota@ (in constructor order) and its
-- integer type.
type IotaInp =
  ( (VName, LParam SOACS),
    (SubExp, SubExp, SubExp, IntType)
  )

-- | A scatter input read through a single reshape: the input array's name
-- paired with the scatter-lambda parameter bound to it, followed by the
-- flattened outer shape, the unflattened outer shape, and the element type
-- (the type of the lambda parameter).
type RshpInp =
  ( (VName, LParam SOACS),
    (Shape, Shape, Type)
  )
-- | Try to rewrite the scatter held by @node_to_fuse@ into a statement built
-- by 'mkWithAccStm', absorbing the reshapes applied to its inputs.
--
-- The rule applies only when all of the following hold:
--
--   * the node is a 'SoacNode' without output transforms whose SOAC is an
--     'H.Scatter', and at least one scatter input carries array transforms;
--   * every scatter input has exactly one direct-dependency edge named after
--     its array, and the scatter lambda has at least as many parameters as
--     there are inputs;
--   * each input is either produced by an @Int64@ @Iota@ and used without
--     transforms, or is read through exactly one 'ReshapeArbitrary' whose
--     outer shape is rank 1 and strictly lower-rank than the array's outer
--     shape;
--   * there is at least one reshaped input, and all reshaped inputs agree on
--     both their flattened and unflattened outer shapes;
--   * 'checkSafeAndProfitable' accepts the reshape producers and the
--     consumer-edge neighbours of the scatter.
--
-- On success the scatter node is replaced, under the same node id, by a
-- 'StmNode' holding the new statement. The certificates of all absorbed
-- reshapes are prepended to the statement's certificates. Input edges named
-- after iota arrays are replaced by the iota nodes' own outgoing edges (the
-- last component of their contexts). Returns 'Nothing' when the rule does
-- not apply.
ruleMFScat ::
  (HasScope SOACS m, MonadFreshNames m) =>
  DepNode ->
  DepGraph ->
  m (Maybe DepGraph)
ruleMFScat node_to_fuse dg@DepGraph {dgGraph = g}
  | SoacNode node_out_trsfs scat_pat scat_soac scat_aux <- snd node_to_fuse,
    scat_node_id <- nodeFromLNode node_to_fuse,
    H.nullTransforms node_out_trsfs,
    H.Scatter _len scat_inp scat_out scat_lam <- scat_soac,
    -- The rule is only worthwhile if some input is transformed.
    any ((/= mempty) . H.inputTransforms) (H.inputs scat_soac),
    (out_deps, _, _, inp_deps) <- G.context g scat_node_id,
    cons_ctxs <- map (G.context g . snd) $ filter (isCons . fst) inp_deps,
    drct_ctx_edges <-
      map (\(e, n) -> (G.context g n, e)) $ filter (isDep . fst) inp_deps,
    -- Every scatter input must be matched to a direct dependency.
    Just drct_tups0 <- traverse (pairUp drct_ctx_edges) scat_inp,
    (t1s, t2s) <- unzip drct_tups0,
    drct_tups <- zip t1s $ zip t2s (lambdaParams scat_lam),
    -- 'zip' must not have dropped an input for lack of a lambda parameter.
    length drct_tups == length scat_inp,
    (ctxs_iots, drct_iots) <- unzip $ filter (isIota . snd . fst . snd) drct_tups,
    (ctxs_rshp, drct_rshp) <- unzip $ filter (not . isIota . snd . fst . snd) drct_tups,
    -- Every iota input and every reshaped input must have a usable form.
    Just rep_iotas <- traverse getRepIota drct_iots,
    Just rep_rshps_certs <- traverse getRepRshpArr drct_rshp,
    (rep_rshps, certs_rshps) <- unzip rep_rshps_certs,
    -- At least one reshaped input, and all of them agree on their shapes.
    (_, (s1, s2, _)) : _ <- rep_rshps,
    all (\(_, (s1', s2', _)) -> s1 == s1' && s2 == s2') rep_rshps,
    checkSafeAndProfitable dg scat_node_id ctxs_rshp cons_ctxs = do
      let cons_patels_outs = zip (patElems scat_pat) scat_out
      wacc_stm <- mkWithAccStm rep_iotas rep_rshps cons_patels_outs scat_aux scat_lam
      let aux = stmAux wacc_stm
          -- The reshapes disappear, so their certificates must be kept.
          aux' = aux {stmAuxCerts = mconcat certs_rshps <> stmAuxCerts aux}
          wacc_stm' = wacc_stm {stmAux = aux'}
          -- Edges to the iota arrays are replaced by the iota nodes' own
          -- outgoing edges.
          iota_nms = namesFromList $ map (fst . fst) rep_iotas
          inp_deps_wo_iotas = filter ((`notNameIn` iota_nms) . getName . fst) inp_deps
          deps_of_iotas = concatMap (\(_, _, _, deps) -> deps) ctxs_iots
          new_withacc_ctx =
            ( out_deps,
              scat_node_id,
              StmNode wacc_stm',
              inp_deps_wo_iotas <> deps_of_iotas
            )
      pure $ Just $ dg {dgGraph = new_withacc_ctx G.& G.delNodes [scat_node_id] g}
  where
    -- Pair a scatter input with the unique context whose edge is named after
    -- the input array; 'Nothing' if there is no such edge or several.
    pairUp :: [(DepContext, EdgeT)] -> H.Input -> Maybe (DepContext, (H.Input, NodeT))
    pairUp ctxes inp@(H.Input _ nm _)
      | [(ctx@(_, _, nT, _), _)] <- filter ((== nm) . getName . snd) ctxes =
          Just (ctx, (inp, nT))
    pairUp _ _ = Nothing

    -- Is the node a statement binding an 'Iota'?
    isIota :: NodeT -> Bool
    isIota (StmNode (Let _ _ (BasicOp (Iota {})))) = True
    isIota _ = False

    -- Accept an input used without transforms and produced by an Int64 iota.
    getRepIota :: ((H.Input, NodeT), LParam SOACS) -> Maybe IotaInp
    getRepIota ((H.Input iottrsf arr_nm _arr_tp, nt), farg)
      | mempty == iottrsf,
        StmNode (Let _ _ (BasicOp (Iota n x s Int64))) <- nt =
          Just ((arr_nm, farg), (n, x, s, Int64))
    getRepIota _ = Nothing

    -- Accept an input whose only transform is an arbitrary reshape. After
    -- stripping the lambda parameter's element dimensions, the reshaped
    -- (flat) outer shape must be rank 1 and of lower rank than the array's
    -- own outer shape. Also returns the reshape's certificates.
    getRepRshpArr :: ((H.Input, NodeT), LParam SOACS) -> Maybe (RshpInp, Certs)
    getRepRshpArr ((H.Input outtrsf arr_nm arr_tp, _nt), farg)
      | rshp_trsfm H.:< other_trsfms <- H.viewf outtrsf,
        H.Reshape c ReshapeArbitrary shp_flat <- rshp_trsfm,
        other_trsfms == mempty,
        eltp <- paramDec farg,
        Just shp_flat' <- checkShp eltp shp_flat,
        Array _ptp shp_unflat _ <- arr_tp,
        Just shp_unflat' <- checkShp eltp shp_unflat,
        shapeRank shp_flat' == 1,
        shapeRank shp_flat' < shapeRank shp_unflat' =
          Just (((arr_nm, farg), (shp_flat', shp_unflat', eltp)), c)
    getRepRshpArr _ = Nothing

    -- If the trailing dimensions of @shp_arr@ are exactly the shape of the
    -- element type, return the remaining outer dimensions. Returns
    -- 'Nothing' when the element has more dimensions than @shp_arr@ (the
    -- previous take/drop/zip formulation wrongly produced @Just (Shape [])@
    -- in that case).
    checkShp :: Type -> Shape -> Maybe Shape
    checkShp (Prim _) shp_arr = Just shp_arr
    checkShp (Array _ shp_elm _) shp_arr
      | n >= m,
        dims_com == dims_elm =
          Just $ Shape dims_outer
      where
        dims_elm = shapeDims shp_elm
        dims_arr = shapeDims shp_arr
        m = length dims_elm
        n = length dims_arr
        (dims_outer, dims_com) = splitAt (n - m) dims_arr
    checkShp _ _ = Nothing
ruleMFScat _ _ = pure Nothing
-- | Decide whether rewriting the scatter at @scat_node_id@ is both safe and
-- profitable. @ctxs_rshp@ are the contexts of the nodes producing the
-- scatter's reshaped inputs, and @ctxs_cons@ are the contexts reached
-- through the scatter's consumer edges.
--
-- All of the following must hold:
--
--   * every edge in the first component of each given context (the
--     component 'ruleMFScat' names @out_deps@) points at @scat_node_id@;
--   * all reshaped inputs come from one and the same node;
--   * that node is a 'H.Screma' with no scans or reductions and no output
--     transforms;
--   * 'vFusionFeasability' accepts fusing that node into the scatter.
--
-- Returns 'False' when @ctxs_rshp@ is empty.
checkSafeAndProfitable :: DepGraph -> G.Node -> [DepContext] -> [DepContext] -> Bool
checkSafeAndProfitable dg scat_node_id ctxs_rshp@((_, map_node_id, map_nT, _) : _) ctxs_cons =
  safe && prof1 && prof2 && prof3
  where
    all_deps = concatMap (\(x, _, _, _) -> x) $ ctxs_rshp ++ ctxs_cons
    -- The scatter is the only node on the other end of these edges.
    prof1 = all (\(_, dep_id) -> dep_id == scat_node_id) all_deps
    -- All reshaped inputs come from the same node.
    prof2 = all (\(_, nid, _, _) -> nid == map_node_id) ctxs_rshp
    -- That node is a plain map.
    prof3 = isMap map_nT
    safe = vFusionFeasability dg map_node_id scat_node_id

    isMap :: NodeT -> Bool
    isMap nT
      | SoacNode out_trsfs _pat soac _ <- nT,
        H.Screma _ _ form <- soac,
        ScremaForm _ [] [] <- form =
          H.nullTransforms out_trsfs
    isMap _ = False
checkSafeAndProfitable _ _ _ _ = False
mkWithAccStm ::
(HasScope SOACS m, MonadFreshNames m) =>
[IotaInp] ->
[RshpInp] ->
[(PatElem (LetDec SOACS), (Shape, Int, VName))] ->
StmAux (ExpDec SOACS) ->
Lambda SOACS ->
m (Stm SOACS)
mkWithAccStm :: forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
[IotaInp]
-> [RshpInp]
-> [(PatElem (LetDec SOACS), (ShapeBase SubExp, Int, VName))]
-> StmAux (ExpDec SOACS)
-> Lambda SOACS
-> m (Stm SOACS)
mkWithAccStm [IotaInp]
iota_inps [RshpInp]
rshp_inps [(PatElem (LetDec SOACS), (ShapeBase SubExp, Int, VName))]
cons_patels_outs StmAux (ExpDec SOACS)
scatter_aux Lambda SOACS
scatter_lam
| RshpInp
rshp_inp : [RshpInp]
_ <- [RshpInp]
rshp_inps,
((VName, LParam SOACS)
_, (ShapeBase SubExp
_, ShapeBase SubExp
s_unflat, Type
_)) <- RshpInp
rshp_inp,
(SubExp
_ : [SubExp]
_) <- ShapeBase SubExp -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
s_unflat = do
([Param Type]
cert_params, [Param Type]
acc_params) <- ([(Param Type, Param Type)] -> ([Param Type], [Param Type]))
-> m [(Param Type, Param Type)] -> m ([Param Type], [Param Type])
forall a b. (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [(Param Type, Param Type)] -> ([Param Type], [Param Type])
forall a b. [(a, b)] -> ([a], [b])
unzip (m [(Param Type, Param Type)] -> m ([Param Type], [Param Type]))
-> m [(Param Type, Param Type)] -> m ([Param Type], [Param Type])
forall a b. (a -> b) -> a -> b
$
[(PatElem Type, (ShapeBase SubExp, Int, VName))]
-> ((PatElem Type, (ShapeBase SubExp, Int, VName))
-> m (Param Type, Param Type))
-> m [(Param Type, Param Type)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [(PatElem Type, (ShapeBase SubExp, Int, VName))]
[(PatElem (LetDec SOACS), (ShapeBase SubExp, Int, VName))]
cons_patels_outs (((PatElem Type, (ShapeBase SubExp, Int, VName))
-> m (Param Type, Param Type))
-> m [(Param Type, Param Type)])
-> ((PatElem Type, (ShapeBase SubExp, Int, VName))
-> m (Param Type, Param Type))
-> m [(Param Type, Param Type)]
forall a b. (a -> b) -> a -> b
$ \(PatElem Type
patel, (ShapeBase SubExp
shp, Int
_, VName
nm)) -> do
Param Type
cert_param <- String -> Type -> m (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"acc_cert_p" (Type -> m (Param Type)) -> Type -> m (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Unit
let arr_tp :: Type
arr_tp = PatElem Type -> Type
forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem Type
patel
acc_tp :: Type
acc_tp = Int -> Type -> Type
forall u.
Int
-> TypeBase (ShapeBase SubExp) u -> TypeBase (ShapeBase SubExp) u
stripArray (ShapeBase SubExp -> Int
forall a. ArrayShape a => a -> Int
shapeRank ShapeBase SubExp
shp) Type
arr_tp
Param Type
acc_param <-
String -> Type -> m (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
nm) (Type -> m (Param Type)) -> Type -> m (Param Type)
forall a b. (a -> b) -> a -> b
$
VName -> ShapeBase SubExp -> [Type] -> NoUniqueness -> Type
forall shape u.
VName -> ShapeBase SubExp -> [Type] -> u -> TypeBase shape u
Acc (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
cert_param) ShapeBase SubExp
shp [Type
acc_tp] NoUniqueness
NoUniqueness
(Param Type, Param Type) -> m (Param Type, Param Type)
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Param Type
cert_param, Param Type
acc_param)
let cons_params_outs :: [(Param Type, (ShapeBase SubExp, Int, VName))]
cons_params_outs = [Param Type]
-> ScatterSpec VName
-> [(Param Type, (ShapeBase SubExp, Int, VName))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param Type]
acc_params (ScatterSpec VName
-> [(Param Type, (ShapeBase SubExp, Int, VName))])
-> ScatterSpec VName
-> [(Param Type, (ShapeBase SubExp, Int, VName))]
forall a b. (a -> b) -> a -> b
$ ((PatElem Type, (ShapeBase SubExp, Int, VName))
-> (ShapeBase SubExp, Int, VName))
-> [(PatElem Type, (ShapeBase SubExp, Int, VName))]
-> ScatterSpec VName
forall a b. (a -> b) -> [a] -> [b]
map (PatElem Type, (ShapeBase SubExp, Int, VName))
-> (ShapeBase SubExp, Int, VName)
forall a b. (a, b) -> b
snd [(PatElem Type, (ShapeBase SubExp, Int, VName))]
[(PatElem (LetDec SOACS), (ShapeBase SubExp, Int, VName))]
cons_patels_outs
Body SOACS
acc_bdy <- ShapeBase SubExp
-> [IotaInp]
-> [RshpInp]
-> [(LParam SOACS, (ShapeBase SubExp, Int, VName))]
-> Lambda SOACS
-> m (Body SOACS)
forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
ShapeBase SubExp
-> [IotaInp]
-> [RshpInp]
-> [(LParam SOACS, (ShapeBase SubExp, Int, VName))]
-> Lambda SOACS
-> m (Body SOACS)
mkWithAccBdy ShapeBase SubExp
s_unflat [IotaInp]
iota_inps [RshpInp]
rshp_inps [(Param Type, (ShapeBase SubExp, Int, VName))]
[(LParam SOACS, (ShapeBase SubExp, Int, VName))]
cons_params_outs Lambda SOACS
scatter_lam
let withacc_lam :: Lambda SOACS
withacc_lam =
Lambda
{ lambdaParams :: [LParam SOACS]
lambdaParams = [Param Type]
cert_params [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
acc_params,
lambdaReturnType :: [Type]
lambdaReturnType = (Param Type -> Type) -> [Param Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Type
forall dec. Param dec -> dec
paramDec [Param Type]
acc_params,
lambdaBody :: Body SOACS
lambdaBody = Body SOACS
acc_bdy
}
withacc_inps :: [(ShapeBase SubExp, [VName], Maybe a)]
withacc_inps = ((PatElem Type, (ShapeBase SubExp, Int, VName))
-> (ShapeBase SubExp, [VName], Maybe a))
-> [(PatElem Type, (ShapeBase SubExp, Int, VName))]
-> [(ShapeBase SubExp, [VName], Maybe a)]
forall a b. (a -> b) -> [a] -> [b]
map (\(PatElem Type
_, (ShapeBase SubExp
shp, Int
_, VName
nm)) -> (ShapeBase SubExp
shp, [VName
nm], Maybe a
forall a. Maybe a
Nothing)) [(PatElem Type, (ShapeBase SubExp, Int, VName))]
[(PatElem (LetDec SOACS), (ShapeBase SubExp, Int, VName))]
cons_patels_outs
withacc_pat :: Pat Type
withacc_pat = [PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat ([PatElem Type] -> Pat Type) -> [PatElem Type] -> Pat Type
forall a b. (a -> b) -> a -> b
$ ((PatElem Type, (ShapeBase SubExp, Int, VName)) -> PatElem Type)
-> [(PatElem Type, (ShapeBase SubExp, Int, VName))]
-> [PatElem Type]
forall a b. (a -> b) -> [a] -> [b]
map (PatElem Type, (ShapeBase SubExp, Int, VName)) -> PatElem Type
forall a b. (a, b) -> a
fst [(PatElem Type, (ShapeBase SubExp, Int, VName))]
[(PatElem (LetDec SOACS), (ShapeBase SubExp, Int, VName))]
cons_patels_outs
stm :: Stm SOACS
stm =
Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
withacc_pat StmAux (ExpDec SOACS)
scatter_aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$
[WithAccInput SOACS] -> Lambda SOACS -> Exp SOACS
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [WithAccInput SOACS]
forall {a}. [(ShapeBase SubExp, [VName], Maybe a)]
withacc_inps Lambda SOACS
withacc_lam
Stm SOACS -> m (Stm SOACS)
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Stm SOACS
stm
mkWithAccStm [IotaInp]
_ [RshpInp]
_ [(PatElem (LetDec SOACS), (ShapeBase SubExp, Int, VName))]
_ StmAux (ExpDec SOACS)
_ Lambda SOACS
_ =
String -> m (Stm SOACS)
forall a. HasCallStack => String -> a
error String
"Unreachable case reached!"
-- | Build the body of the @WithAcc@ lambda that replaces a scatter.
--
-- The consumed accumulator parameters and the scatter output spec are split
-- out of @cons_params_outs@. Each reshape input gets a fresh-typed parameter
-- that reuses the input's array name, with type @arrayOfShape t s@. The
-- actual construction is delegated to 'mkWithAccBdy'', starting with all
-- dimensions of @shp@ still to be processed, none processed yet, and no iota
-- parameter names bound.
mkWithAccBdy ::
  (HasScope SOACS m, MonadFreshNames m) =>
  Shape ->
  [IotaInp] ->
  [RshpInp] ->
  [(LParam SOACS, (Shape, Int, VName))] ->
  Lambda SOACS ->
  m (Body SOACS)
mkWithAccBdy shp iota_inps rshp_inps cons_params_outs scat_lam =
  mkWithAccBdy' static_arg (shapeDims shp) [] [] flat_rshp_ps acc_ps
  where
    (acc_ps, res_info) = unzip cons_params_outs
    static_arg = (iota_inps, rshp_inps, res_info, scat_lam)
    flat_rshp_ps = map toRshpParam rshp_inps
    -- A parameter named after the reshape input array, typed as the element
    -- type @t@ lifted to the shape @s@ (the second shape of the 'RshpInp').
    toRshpParam ((arr_nm, _), (_, s, t)) = Param mempty arr_nm (arrayOfShape t s)
-- | Worker for 'mkWithAccBdy'. It builds the body of the @WithAcc@ lambda
-- by recursing over the dimensions of the scatter shape.
--
-- Arguments, in order:
--
-- * the static argument: iota inputs, reshape inputs, scatter output spec
--   and scatter lambda. It is passed unchanged to every recursive call.
-- * the dimensions still to be processed, outermost first.
-- * the dimensions already processed, innermost first (built by consing).
-- * the names of the iota parameters bound by the enclosing maps, outermost
--   first, one per processed dimension.
-- * the current reshape parameters. Each level peels one array dimension
--   off them.
-- * the current accumulator parameters. Their types are unchanged between
--   levels.
--
-- Each remaining dimension produces a map ('F.Screma' with no scans or
-- reductions) over an @iota@ of that dimension. At the innermost level, the
-- body of the scatter lambda is inlined and its writes become @UpdateAcc@
-- operations on the accumulators.
mkWithAccBdy' ::
  (HasScope SOACS m, MonadFreshNames m) =>
  ([IotaInp], [RshpInp], [(Shape, Int, VName)], Lambda SOACS) ->
  [SubExp] ->
  [SubExp] ->
  [VName] ->
  [LParam SOACS] ->
  [LParam SOACS] ->
  m (Body SOACS)
-- Innermost case: every dimension has been mapped over.
mkWithAccBdy' (iota_inps, rshp_inps, scat_res_info, scat_lam) [] dims_rev iot_par_nms rshp_ps cons_ps = do
  scope <- askScope
  runBodyBuilder . localScope (scope <> scopeOfLParams (rshp_ps ++ cons_ps)) $ do
    -- Linearise the multi-dimensional index held in the iota parameters,
    -- row-major: the stride of dimension k is the product of all later
    -- dimensions. 'scanl' always yields a non-empty list, so dropping its
    -- first element after reversing (the total product) is safe. 'drop 1'
    -- is used instead of the partial 'tail'.
    let strides_rev = scanl (*) (pe64 se1) $ map pe64 dims_rev
        strides = drop 1 $ reverse strides_rev
        i_pe = sum $ zipWith (*) (map le64 iot_par_nms) strides
    i_norm <- letExp "iota_norm_arg" =<< toExp i_pe
    -- Bind the lambda parameter of each iota input to start + i_norm * stride.
    -- NOTE(review): the iota's IntType component is ignored and the parameter
    -- is always bound as i64. Confirm that callers only admit Int64 iotas.
    forM_ iota_inps $ \((_, i_par), (_, b, s, _)) -> do
      i_new <- letExp "tmp" =<< toExp (pe64 b + le64 i_norm * pe64 s)
      letBind (Pat [PatElem (paramName i_par) (Prim $ IntType Int64)]) $
        BasicOp $
          SubExp $
            Var i_new
    -- Rebind the original reshape lambda parameters to this level's reshape
    -- parameters.
    let rshp_lam_args = map (snd . fst) rshp_inps
    forM_ (zip rshp_lam_args rshp_ps) $ \(old_par, new_par) -> do
      let pat = Pat [PatElem (paramName old_par) (paramDec old_par)]
      letBind pat $ BasicOp $ SubExp $ Var $ paramName new_par
    -- Inline the statements of the scatter lambda's body.
    mapM_ addStm $ bodyStms $ lambdaBody scat_lam
    -- Turn the scatter writes into accumulator updates.
    let iv_ses = groupScatterResults' scat_res_info $ bodyResult $ lambdaBody scat_lam
    res_nms <-
      forM (zip cons_ps iv_ses) $ \(cons_p, (i_ses, v_se)) -> do
        -- NOTE(review): this pairs one accumulator with one write, and emits
        -- one single-index UpdateAcc per index component. It looks correct
        -- only when every scatter output has exactly one write with a 1-D
        -- index. Verify against the guards in ruleMFScat.
        let f nm_in i_se =
              letExp (baseString nm_in) $
                BasicOp $
                  UpdateAcc Safe nm_in [resSubExp i_se] [resSubExp v_se]
        foldM f (paramName cons_p) i_ses
    -- The certificates of all scatter-lambda results are attached to every
    -- result of this body.
    let lam_certs = foldMap resCerts $ bodyResult $ lambdaBody scat_lam
    pure $ map (SubExpRes lam_certs . Var) res_nms
-- Recursive case: map over the outermost remaining dimension.
mkWithAccBdy' static_arg (dim : dims) dims_rev iot_par_nms rshp_ps cons_ps = do
  scope <- askScope
  runBodyBuilder . localScope (scope <> scopeOfLParams (rshp_ps ++ cons_ps)) $ do
    iota_arr <- letExp "iota_arr" $ BasicOp $ Iota dim se0 se1 Int64
    iota_p <- newParam "iota_arg" $ Prim $ IntType Int64
    -- The map's reshape parameters drop one (the outermost) array dimension.
    rshp_ps' <- forM (zip [0 :: Int ..] (map paramDec rshp_ps)) $ \(i, arr_tp) ->
      newParam ("rshp_arg_" ++ show i) $ stripArray 1 arr_tp
    -- The map's accumulator parameters keep the accumulator types as they are.
    cons_ps' <- forM (zip [0 :: Int ..] (map paramDec cons_ps)) $ \(i, acc_tp) ->
      newParam ("acc_arg_" ++ show i) acc_tp
    map_lam_bdy <-
      mkWithAccBdy'
        static_arg
        dims
        (dim : dims_rev)
        (iot_par_nms ++ [paramName iota_p])
        rshp_ps'
        cons_ps'
    let map_lam =
          Lambda (rshp_ps' ++ [iota_p] ++ cons_ps') (map paramDec cons_ps') map_lam_bdy
        map_inps = map paramName rshp_ps ++ [iota_arr] ++ map paramName cons_ps
        map_soac = F.Screma dim map_inps $ ScremaForm map_lam [] []
    res_nms <- letTupExp "acc_res" $ Op map_soac
    pure $ map (subExpRes . Var) res_nms
-- | The pieces of a single accumulator of a @WithAcc@ statement, kept
-- together as one tuple (lists of these are produced by 'groupAccs' and
-- consumed by 'tryFuseWithAccs'):
--
--   1. the pattern elements bound for this accumulator;
--   2. its 'WithAccInput' (shape, input arrays, optional lambda/sub-exps);
--   3. a lambda parameter gathered into the leading parameter group of the
--      fused lambda (presumably the accumulator token — confirm in groupAccs);
--   4. the lambda parameter whose type is the accumulator's return type;
--   5. the body-result name for this accumulator together with its certs.
type AccTup =
  ( [PatElem (LetDec SOACS)], -- (1) bound pattern elements
    WithAccInput SOACS, -- (2) with_acc input
    LParam SOACS, -- (3) leading-group lambda parameter
    LParam SOACS, -- (4) accumulator lambda parameter
    (VName, Certs) -- (5) body result name and its certificates
  )
-- | Project the pattern elements bound for this accumulator (component 1).
accTup1 :: AccTup -> [PatElem (LetDec SOACS)]
accTup1 (pes, _, _, _, _) = pes
-- | Project the 'WithAccInput' of this accumulator (component 2): a triple of
-- shape, input array names and an optional lambda with sub-expressions.
accTup2 :: AccTup -> WithAccInput SOACS
accTup2 (_, inp, _, _, _) = inp
-- | Project the first lambda parameter of this accumulator (component 3).
-- In 'tryFuseWithAccs' these parameters form the leading group of the fused
-- lambda's parameters, and the names from the second statement are renamed
-- to those of the first. NOTE(review): presumably the accumulator token /
-- certificate parameter — confirm against 'groupAccs'.
accTup3 :: AccTup -> LParam SOACS
accTup3 (_, _, p, _, _) = p
-- | Project the accumulator lambda parameter (component 4). Its type
-- ('paramDec') is used as the corresponding return type of the fused lambda.
accTup4 :: AccTup -> LParam SOACS
accTup4 (_, _, _, p, _) = p
-- | Project the lambda-body result for this accumulator (component 5): the
-- name the body returns for it, paired with the certificates on that result.
accTup5 :: AccTup -> (VName, Certs)
accTup5 (_, _, _, _, res) = res
tryFuseWithAccs ::
(HasScope SOACS m, MonadFreshNames m) =>
[VName] ->
Stm SOACS ->
Stm SOACS ->
m (Maybe (Stm SOACS))
tryFuseWithAccs :: forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
[VName] -> Stm SOACS -> Stm SOACS -> m (Maybe (Stm SOACS))
tryFuseWithAccs
[VName]
infusible
(Let Pat (LetDec SOACS)
pat1 StmAux (ExpDec SOACS)
aux1 (WithAcc [WithAccInput SOACS]
w_inps1 Lambda SOACS
lam1))
(Let Pat (LetDec SOACS)
pat2 StmAux (ExpDec SOACS)
aux2 (WithAcc [WithAccInput SOACS]
w_inps2 Lambda SOACS
lam2))
| ([PatElem Type]
pat1_els, [PatElem Type]
pat2_els) <- (Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
Pat (LetDec SOACS)
pat1, Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
Pat (LetDec SOACS)
pat2),
([AccTup]
acc_tup1, [(PatElem (LetDec SOACS), SubExpRes)]
other_pr1) <- [PatElem (LetDec SOACS)]
-> [WithAccInput SOACS]
-> Lambda SOACS
-> ([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
groupAccs [PatElem Type]
[PatElem (LetDec SOACS)]
pat1_els [WithAccInput SOACS]
w_inps1 Lambda SOACS
lam1,
([AccTup]
acc_tup2, [(PatElem (LetDec SOACS), SubExpRes)]
other_pr2) <- [PatElem (LetDec SOACS)]
-> [WithAccInput SOACS]
-> Lambda SOACS
-> ([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
groupAccs [PatElem Type]
[PatElem (LetDec SOACS)]
pat2_els [WithAccInput SOACS]
w_inps2 Lambda SOACS
lam2,
([(AccTup, AccTup)]
tup_common, [AccTup]
acc_tup1', [AccTup]
acc_tup2') <-
[AccTup] -> [AccTup] -> ([(AccTup, AccTup)], [AccTup], [AccTup])
groupCommonAccs [AccTup]
acc_tup1 [AccTup]
acc_tup2,
[VName]
pnms_1' <- (PatElem Type -> VName) -> [PatElem Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map PatElem Type -> VName
forall dec. PatElem dec -> VName
patElemName ([PatElem Type] -> [VName]) -> [PatElem Type] -> [VName]
forall a b. (a -> b) -> a -> b
$ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> [PatElem Type])
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [PatElem Type]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (\([PatElem Type]
nms, WithAccInput SOACS
_, Param Type
_, Param Type
_, (VName, Certs)
_) -> [PatElem Type]
nms) [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
acc_tup1',
[VName]
winp_2' <- (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> [VName])
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (\([PatElem Type]
_, (ShapeBase SubExp
_, [VName]
nms, Maybe (Lambda SOACS, [SubExp])
_), Param Type
_, Param Type
_, (VName, Certs)
_) -> [VName]
nms) [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
acc_tup2',
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Names -> Names -> Bool
namesIntersect ([VName] -> Names
namesFromList [VName]
pnms_1') ([VName] -> Names
namesFromList [VName]
winp_2'),
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Names -> Names -> Bool
namesIntersect ([VName] -> Names
namesFromList [VName]
pnms_1') (Lambda SOACS -> Names
forall a. FreeIn a => a -> Names
freeIn Lambda SOACS
lam2),
[VName]
bs <- (PatElem Type -> VName) -> [PatElem Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map PatElem Type -> VName
forall dec. PatElem dec -> VName
patElemName ([PatElem Type] -> [VName]) -> [PatElem Type] -> [VName]
forall a b. (a -> b) -> a -> b
$ ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> [PatElem Type])
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
-> [PatElem Type]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> [PatElem Type]
AccTup -> [PatElem (LetDec SOACS)]
accTup1 (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> [PatElem Type])
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> [PatElem Type]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
forall a b. (a, b) -> a
fst) [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
[(AccTup, AccTup)]
tup_common,
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` [VName]
infusible) [VName]
bs,
Names
cs <- [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> [VName])
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap ((\(ShapeBase SubExp
_, [VName]
xs, Maybe (Lambda SOACS, [SubExp])
_) -> [VName]
xs) (WithAccInput SOACS -> [VName])
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> WithAccInput SOACS)
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> WithAccInput SOACS
AccTup -> WithAccInput SOACS
accTup2) [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
acc_tup2,
((PatElem Type, SubExpRes) -> Bool)
-> [(PatElem Type, SubExpRes)] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all ((VName -> Names -> Bool
`notNameIn` Names
cs) (VName -> Bool)
-> ((PatElem Type, SubExpRes) -> VName)
-> (PatElem Type, SubExpRes)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElem Type -> VName
forall dec. PatElem dec -> VName
patElemName (PatElem Type -> VName)
-> ((PatElem Type, SubExpRes) -> PatElem Type)
-> (PatElem Type, SubExpRes)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatElem Type, SubExpRes) -> PatElem Type
forall a b. (a, b) -> a
fst) [(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
other_pr1 = do
let getCertPairs :: (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> (VName, VName)
getCertPairs (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
t1, ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
t2) = (Param Type -> VName
forall dec. Param dec -> VName
paramName (AccTup -> LParam SOACS
accTup3 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
AccTup
t2), Param Type -> VName
forall dec. Param dec -> VName
paramName (AccTup -> LParam SOACS
accTup3 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
AccTup
t1))
tab_certs :: Map VName VName
tab_certs = [(VName, VName)] -> Map VName VName
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VName, VName)] -> Map VName VName)
-> [(VName, VName)] -> Map VName VName
forall a b. (a -> b) -> a -> b
$ ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> (VName, VName))
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
-> [(VName, VName)]
forall a b. (a -> b) -> [a] -> [b]
map (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> (VName, VName)
getCertPairs [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
[(AccTup, AccTup)]
tup_common
lam2_bdy' :: Body SOACS
lam2_bdy' = Map VName VName -> Body SOACS -> Body SOACS
forall a. Substitute a => Map VName VName -> a -> a
substituteNames Map VName VName
tab_certs (Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam2)
rcrt_params :: [Param Type]
rcrt_params = ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> Param Type)
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
-> [Param Type]
forall a b. (a -> b) -> [a] -> [b]
map (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Param Type
AccTup -> LParam SOACS
accTup3 (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Param Type)
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> Param Type
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
forall a b. (a, b) -> a
fst) [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
[(AccTup, AccTup)]
tup_common [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Param Type)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [Param Type]
forall a b. (a -> b) -> [a] -> [b]
map ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Param Type
AccTup -> LParam SOACS
accTup3 [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
acc_tup1' [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Param Type)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [Param Type]
forall a b. (a -> b) -> [a] -> [b]
map ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Param Type
AccTup -> LParam SOACS
accTup3 [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
acc_tup2'
racc_params :: [Param Type]
racc_params = ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> Param Type)
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
-> [Param Type]
forall a b. (a -> b) -> [a] -> [b]
map (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Param Type
AccTup -> LParam SOACS
accTup4 (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Param Type)
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> Param Type
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
forall a b. (a, b) -> a
fst) [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
[(AccTup, AccTup)]
tup_common [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Param Type)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [Param Type]
forall a b. (a -> b) -> [a] -> [b]
map ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Param Type
AccTup -> LParam SOACS
accTup4 [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
acc_tup1' [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Param Type)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [Param Type]
forall a b. (a -> b) -> [a] -> [b]
map ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Param Type
AccTup -> LParam SOACS
accTup4 [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
acc_tup2'
([VName]
comm_res_nms, [Certs]
comm_res_certs2) = [(VName, Certs)] -> ([VName], [Certs])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, Certs)] -> ([VName], [Certs]))
-> [(VName, Certs)] -> ([VName], [Certs])
forall a b. (a -> b) -> a -> b
$ ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> (VName, Certs))
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
-> [(VName, Certs)]
forall a b. (a -> b) -> [a] -> [b]
map (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> (VName, Certs)
AccTup -> (VName, Certs)
accTup5 (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> (VName, Certs))
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> (VName, Certs)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
forall a b. (a, b) -> b
snd) [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
[(AccTup, AccTup)]
tup_common
([VName]
_, [Certs]
comm_res_certs1) = [(VName, Certs)] -> ([VName], [Certs])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, Certs)] -> ([VName], [Certs]))
-> [(VName, Certs)] -> ([VName], [Certs])
forall a b. (a -> b) -> a -> b
$ ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> (VName, Certs))
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
-> [(VName, Certs)]
forall a b. (a -> b) -> [a] -> [b]
map (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> (VName, Certs)
AccTup -> (VName, Certs)
accTup5 (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> (VName, Certs))
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> (VName, Certs)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
forall a b. (a, b) -> a
fst) [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
[(AccTup, AccTup)]
tup_common
com_res_certs :: [Certs]
com_res_certs = (Certs -> Certs -> Certs) -> [Certs] -> [Certs] -> [Certs]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\Certs
x Certs
y -> [VName] -> Certs
Certs (Certs -> [VName]
unCerts Certs
x [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ Certs -> [VName]
unCerts Certs
y)) [Certs]
comm_res_certs1 [Certs]
comm_res_certs2
bdyres_certs :: [Certs]
bdyres_certs = [Certs]
com_res_certs [Certs] -> [Certs] -> [Certs]
forall a. [a] -> [a] -> [a]
++ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Certs)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [Certs]
forall a b. (a -> b) -> [a] -> [b]
map ((VName, Certs) -> Certs
forall a b. (a, b) -> b
snd ((VName, Certs) -> Certs)
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> (VName, Certs))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Certs
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> (VName, Certs)
AccTup -> (VName, Certs)
accTup5) ([([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
acc_tup1' [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
forall a. [a] -> [a] -> [a]
++ [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
acc_tup2')
bdyres_accse :: [SubExp]
bdyres_accse = (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
comm_res_nms [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> SubExp)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> VName)
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, Certs) -> VName
forall a b. (a, b) -> a
fst ((VName, Certs) -> VName)
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> (VName, Certs))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> (VName, Certs)
AccTup -> (VName, Certs)
accTup5) ([([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
acc_tup1' [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
forall a. [a] -> [a] -> [a]
++ [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
acc_tup2')
bdy_res_accs :: Result
bdy_res_accs = (Certs -> SubExp -> SubExpRes) -> [Certs] -> [SubExp] -> Result
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Certs -> SubExp -> SubExpRes
SubExpRes [Certs]
bdyres_certs [SubExp]
bdyres_accse
bdy_res_others :: Result
bdy_res_others = ((PatElem Type, SubExpRes) -> SubExpRes)
-> [(PatElem Type, SubExpRes)] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (PatElem Type, SubExpRes) -> SubExpRes
forall a b. (a, b) -> b
snd ([(PatElem Type, SubExpRes)] -> Result)
-> [(PatElem Type, SubExpRes)] -> Result
forall a b. (a -> b) -> a -> b
$ [(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
other_pr1 [(PatElem Type, SubExpRes)]
-> [(PatElem Type, SubExpRes)] -> [(PatElem Type, SubExpRes)]
forall a. [a] -> [a] -> [a]
++ [(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
other_pr2
Scope SOACS
scope <- m (Scope SOACS)
forall rep (m :: * -> *). HasScope rep m => m (Scope rep)
askScope
Body SOACS
lam_bdy <-
Builder SOACS Result -> m (Body SOACS)
forall rep (m :: * -> *) somerep.
(Buildable rep, MonadFreshNames m, HasScope somerep m,
SameScope somerep rep) =>
Builder rep Result -> m (Body rep)
runBodyBuilder (Builder SOACS Result -> m (Body SOACS))
-> Builder SOACS Result -> m (Body SOACS)
forall a b. (a -> b) -> a -> b
$ do
Scope SOACS -> Builder SOACS Result -> Builder SOACS Result
forall a.
Scope SOACS
-> BuilderT SOACS (State VNameSource) a
-> BuilderT SOACS (State VNameSource) a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (Scope SOACS
scope Scope SOACS -> Scope SOACS -> Scope SOACS
forall a. Semigroup a => a -> a -> a
<> [Param Type] -> Scope SOACS
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams ([Param Type]
rcrt_params [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
racc_params)) (Builder SOACS Result -> Builder SOACS Result)
-> Builder SOACS Result -> Builder SOACS Result
forall a b. (a -> b) -> a -> b
$ do
(Stm SOACS -> BuilderT SOACS (State VNameSource) ())
-> [Stm SOACS] -> BuilderT SOACS (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
Stm SOACS -> BuilderT SOACS (State VNameSource) ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm ([Stm SOACS] -> BuilderT SOACS (State VNameSource) ())
-> [Stm SOACS] -> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Seq (Stm SOACS) -> [Stm SOACS]
forall rep. Stms rep -> [Stm rep]
stmsToList (Seq (Stm SOACS) -> [Stm SOACS]) -> Seq (Stm SOACS) -> [Stm SOACS]
forall a b. (a -> b) -> a -> b
$ Body SOACS -> Seq (Stm SOACS)
forall rep. Body rep -> Stms rep
bodyStms (Body SOACS -> Seq (Stm SOACS)) -> Body SOACS -> Seq (Stm SOACS)
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam1
[(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> BuilderT SOACS (State VNameSource) ())
-> BuilderT SOACS (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
[(AccTup, AccTup)]
tup_common (((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> BuilderT SOACS (State VNameSource) ())
-> BuilderT SOACS (State VNameSource) ())
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> BuilderT SOACS (State VNameSource) ())
-> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
tup1, ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
tup2) -> do
let (Param Type
lpar1, Param Type
lpar2) = (AccTup -> LParam SOACS
accTup4 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
AccTup
tup1, AccTup -> LParam SOACS
accTup4 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
AccTup
tup2)
((VName
nm1, Certs
_), VName
nm2, Type
tp_acc) = (AccTup -> (VName, Certs)
accTup5 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
AccTup
tup1, Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
lpar2, Param Type -> Type
forall dec. Param dec -> dec
paramDec Param Type
lpar1)
Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind ([PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat [VName -> Type -> PatElem Type
forall dec. VName -> dec -> PatElem dec
PatElem VName
nm2 Type
tp_acc]) (Exp (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ())
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT SOACS (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT SOACS (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT SOACS (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
nm1
[(PatElem Type, SubExpRes)]
-> ((PatElem Type, SubExpRes)
-> BuilderT SOACS (State VNameSource) ())
-> BuilderT SOACS (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
other_pr1 (((PatElem Type, SubExpRes)
-> BuilderT SOACS (State VNameSource) ())
-> BuilderT SOACS (State VNameSource) ())
-> ((PatElem Type, SubExpRes)
-> BuilderT SOACS (State VNameSource) ())
-> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(PatElem Type
pat_elm, SubExpRes
bdy_res) -> do
let (VName
nm, SubExp
se, Type
tp) = (PatElem Type -> VName
forall dec. PatElem dec -> VName
patElemName PatElem Type
pat_elm, SubExpRes -> SubExp
resSubExp SubExpRes
bdy_res, PatElem Type -> Type
forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem Type
pat_elm)
Certs
-> BuilderT SOACS (State VNameSource) ()
-> BuilderT SOACS (State VNameSource) ()
forall a.
Certs
-> BuilderT SOACS (State VNameSource) a
-> BuilderT SOACS (State VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (SubExpRes -> Certs
resCerts SubExpRes
bdy_res) (BuilderT SOACS (State VNameSource) ()
-> BuilderT SOACS (State VNameSource) ())
-> BuilderT SOACS (State VNameSource) ()
-> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind ([PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat [VName -> Type -> PatElem Type
forall dec. VName -> dec -> PatElem dec
PatElem VName
nm Type
tp]) (Exp (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ())
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (SubExp -> BasicOp
SubExp SubExp
se)
(Stm SOACS -> BuilderT SOACS (State VNameSource) ())
-> [Stm SOACS] -> BuilderT SOACS (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
Stm SOACS -> BuilderT SOACS (State VNameSource) ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm ([Stm SOACS] -> BuilderT SOACS (State VNameSource) ())
-> [Stm SOACS] -> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Seq (Stm SOACS) -> [Stm SOACS]
forall rep. Stms rep -> [Stm rep]
stmsToList (Seq (Stm SOACS) -> [Stm SOACS]) -> Seq (Stm SOACS) -> [Stm SOACS]
forall a b. (a -> b) -> a -> b
$ Body SOACS -> Seq (Stm SOACS)
forall rep. Body rep -> Stms rep
bodyStms Body SOACS
lam2_bdy'
Result -> Builder SOACS Result
forall a. a -> BuilderT SOACS (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> Builder SOACS Result) -> Result -> Builder SOACS Result
forall a b. (a -> b) -> a -> b
$ Result
bdy_res_accs Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
bdy_res_others
let tp_res_other :: [Type]
tp_res_other = ((PatElem Type, SubExpRes) -> Type)
-> [(PatElem Type, SubExpRes)] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (PatElem Type -> Type
forall dec. Typed dec => PatElem dec -> Type
patElemType (PatElem Type -> Type)
-> ((PatElem Type, SubExpRes) -> PatElem Type)
-> (PatElem Type, SubExpRes)
-> Type
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatElem Type, SubExpRes) -> PatElem Type
forall a b. (a, b) -> a
fst) ([(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
other_pr1 [(PatElem Type, SubExpRes)]
-> [(PatElem Type, SubExpRes)] -> [(PatElem Type, SubExpRes)]
forall a. [a] -> [a] -> [a]
++ [(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
other_pr2)
res_lam :: Lambda SOACS
res_lam =
Lambda
{ lambdaParams :: [LParam SOACS]
lambdaParams = [Param Type]
rcrt_params [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
racc_params,
lambdaBody :: Body SOACS
lambdaBody = Body SOACS
lam_bdy,
lambdaReturnType :: [Type]
lambdaReturnType = (Param Type -> Type) -> [Param Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Type
forall dec. Param dec -> dec
paramDec [Param Type]
racc_params [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
tp_res_other
}
Lambda SOACS
res_lam' <- Lambda SOACS -> m (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
res_lam
let res_pat :: [PatElem Type]
res_pat =
((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> [PatElem Type])
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
-> [PatElem Type]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> [PatElem Type]
AccTup -> [PatElem (LetDec SOACS)]
accTup1 (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> [PatElem Type])
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> [PatElem Type]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
forall a b. (a, b) -> b
snd) [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
[(AccTup, AccTup)]
tup_common
[PatElem Type] -> [PatElem Type] -> [PatElem Type]
forall a. [a] -> [a] -> [a]
++ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> [PatElem Type])
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [PatElem Type]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> [PatElem Type]
AccTup -> [PatElem (LetDec SOACS)]
accTup1 ([([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
acc_tup1' [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
forall a. [a] -> [a] -> [a]
++ [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
acc_tup2')
[PatElem Type] -> [PatElem Type] -> [PatElem Type]
forall a. [a] -> [a] -> [a]
++ ((PatElem Type, SubExpRes) -> PatElem Type)
-> [(PatElem Type, SubExpRes)] -> [PatElem Type]
forall a b. (a -> b) -> [a] -> [b]
map (PatElem Type, SubExpRes) -> PatElem Type
forall a b. (a, b) -> a
fst ([(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
other_pr1 [(PatElem Type, SubExpRes)]
-> [(PatElem Type, SubExpRes)] -> [(PatElem Type, SubExpRes)]
forall a. [a] -> [a] -> [a]
++ [(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
other_pr2)
res_w_inps :: [WithAccInput SOACS]
res_w_inps = ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> WithAccInput SOACS)
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
-> [WithAccInput SOACS]
forall a b. (a -> b) -> [a] -> [b]
map (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> WithAccInput SOACS
AccTup -> WithAccInput SOACS
accTup2 (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> WithAccInput SOACS)
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> WithAccInput SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
forall a b. (a, b) -> a
fst) [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
[(AccTup, AccTup)]
tup_common [WithAccInput SOACS]
-> [WithAccInput SOACS] -> [WithAccInput SOACS]
forall a. [a] -> [a] -> [a]
++ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> WithAccInput SOACS)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [WithAccInput SOACS]
forall a b. (a -> b) -> [a] -> [b]
map ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> WithAccInput SOACS
AccTup -> WithAccInput SOACS
accTup2 ([([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
acc_tup1' [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
forall a. [a] -> [a] -> [a]
++ [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
acc_tup2')
[WithAccInput SOACS]
res_w_inps' <- (WithAccInput SOACS -> m (WithAccInput SOACS))
-> [WithAccInput SOACS] -> m [WithAccInput SOACS]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM WithAccInput SOACS -> m (WithAccInput SOACS)
forall {m :: * -> *} {rep} {a} {b} {b}.
(Rename (OpC rep rep), Rename (LetDec rep), Rename (ExpDec rep),
Rename (BodyDec rep), Rename (FParamInfo rep),
Rename (LParamInfo rep), Rename (RetType rep),
Rename (BranchType rep), MonadFreshNames m) =>
(a, b, Maybe (Lambda rep, b)) -> m (a, b, Maybe (Lambda rep, b))
renameLamInWAccInp [WithAccInput SOACS]
res_w_inps
let stm_res :: Stm SOACS
stm_res = Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem Type]
res_pat) (StmAux ()
StmAux (ExpDec SOACS)
aux1 StmAux () -> StmAux () -> StmAux ()
forall a. Semigroup a => a -> a -> a
<> StmAux ()
StmAux (ExpDec SOACS)
aux2) (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ [WithAccInput SOACS] -> Lambda SOACS -> Exp SOACS
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [WithAccInput SOACS]
res_w_inps' Lambda SOACS
res_lam'
Maybe (Stm SOACS) -> m (Maybe (Stm SOACS))
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe (Stm SOACS) -> m (Maybe (Stm SOACS)))
-> Maybe (Stm SOACS) -> m (Maybe (Stm SOACS))
forall a b. (a -> b) -> a -> b
$ Stm SOACS -> Maybe (Stm SOACS)
forall a. a -> Maybe a
Just Stm SOACS
stm_res
where
groupAccs ::
[PatElem (LetDec SOACS)] ->
[WithAccInput SOACS] ->
Lambda SOACS ->
([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
groupAccs :: [PatElem (LetDec SOACS)]
-> [WithAccInput SOACS]
-> Lambda SOACS
-> ([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
groupAccs [PatElem (LetDec SOACS)]
pat_els [WithAccInput SOACS]
wacc_inps Lambda SOACS
wlam =
let lam_params :: [LParam SOACS]
lam_params = Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
wlam
n :: Int
n = [Param Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Param Type]
lam_params
([Param Type]
lam_par_crts, [Param Type]
lam_par_accs) = Int -> [Param Type] -> ([Param Type], [Param Type])
forall a. Int -> [a] -> ([a], [a])
splitAt (Int
n Int -> Int -> Int
forall a. Integral a => a -> a -> a
`div` Int
2) [Param Type]
lam_params
lab_res_ses :: Result
lab_res_ses = Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult (Body SOACS -> Result) -> Body SOACS -> Result
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
wlam
in [PatElem (LetDec SOACS)]
-> [WithAccInput SOACS]
-> [LParam SOACS]
-> [LParam SOACS]
-> Result
-> ([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
groupAccsHlp [PatElem (LetDec SOACS)]
pat_els [WithAccInput SOACS]
wacc_inps [Param Type]
[LParam SOACS]
lam_par_crts [Param Type]
[LParam SOACS]
lam_par_accs Result
lab_res_ses
groupAccsHlp ::
[PatElem (LetDec SOACS)] ->
[WithAccInput SOACS] ->
[LParam SOACS] ->
[LParam SOACS] ->
[SubExpRes] ->
([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
groupAccsHlp :: [PatElem (LetDec SOACS)]
-> [WithAccInput SOACS]
-> [LParam SOACS]
-> [LParam SOACS]
-> Result
-> ([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
groupAccsHlp [PatElem (LetDec SOACS)]
pat_els [] [] [] Result
lam_res_ses
| [PatElem Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem Type]
[PatElem (LetDec SOACS)]
pat_els Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Result -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
lam_res_ses =
([], [PatElem Type] -> Result -> [(PatElem Type, SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem Type]
[PatElem (LetDec SOACS)]
pat_els Result
lam_res_ses)
groupAccsHlp
[PatElem (LetDec SOACS)]
pat_els
(winp :: WithAccInput SOACS
winp@(ShapeBase SubExp
_, [VName]
inp, Maybe (Lambda SOACS, [SubExp])
_) : [WithAccInput SOACS]
wacc_inps)
(LParam SOACS
par_crt : [LParam SOACS]
lam_par_crts)
(LParam SOACS
par_acc : [LParam SOACS]
lam_par_accs)
(SubExpRes
res_se : Result
lam_res_ses)
| Int
n <- [VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
inp,
(Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= [PatElem Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem Type]
[PatElem (LetDec SOACS)]
pat_els) Bool -> Bool -> Bool
&& (Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= (Int
1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Result -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
lam_res_ses)),
Var VName
res_nm <- SubExpRes -> SubExp
resSubExp SubExpRes
res_se =
let ([PatElem Type]
pat_els_cur, [PatElem Type]
pat_els') = Int -> [PatElem Type] -> ([PatElem Type], [PatElem Type])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
n [PatElem Type]
[PatElem (LetDec SOACS)]
pat_els
([AccTup]
rec1, [(PatElem (LetDec SOACS), SubExpRes)]
rec2) = [PatElem (LetDec SOACS)]
-> [WithAccInput SOACS]
-> [LParam SOACS]
-> [LParam SOACS]
-> Result
-> ([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
groupAccsHlp [PatElem Type]
[PatElem (LetDec SOACS)]
pat_els' [WithAccInput SOACS]
wacc_inps [LParam SOACS]
lam_par_crts [LParam SOACS]
lam_par_accs Result
lam_res_ses
in (([PatElem Type]
pat_els_cur, WithAccInput SOACS
winp, Param Type
LParam SOACS
par_crt, Param Type
LParam SOACS
par_acc, (VName
res_nm, SubExpRes -> Certs
resCerts SubExpRes
res_se)) ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
forall a. a -> [a] -> [a]
: [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
rec1, [(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
rec2)
groupAccsHlp [PatElem (LetDec SOACS)]
_ [WithAccInput SOACS]
_ [LParam SOACS]
_ [LParam SOACS]
_ Result
_ =
String
-> ([([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))],
[(PatElem Type, SubExpRes)])
forall a. HasCallStack => String -> a
error String
"Unreachable case reached in groupAccsHlp!"
groupCommonAccs :: [AccTup] -> [AccTup] -> ([(AccTup, AccTup)], [AccTup], [AccTup])
groupCommonAccs :: [AccTup] -> [AccTup] -> ([(AccTup, AccTup)], [AccTup], [AccTup])
groupCommonAccs [] [AccTup]
tup_accs2 =
([], [], [AccTup]
tup_accs2)
groupCommonAccs (AccTup
tup_acc1 : [AccTup]
tup_accs1) [AccTup]
tup_accs2
| [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
commons2 <- (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Bool)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
forall a. (a -> Bool) -> [a] -> [a]
filter (AccTup -> AccTup -> Bool
matchingAccTup AccTup
tup_acc1) [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
tup_accs2,
[([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
commons2 Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= Int
1 =
let ([(AccTup, AccTup)]
rec1, [AccTup]
rec2, [AccTup]
rec3) =
[AccTup] -> [AccTup] -> ([(AccTup, AccTup)], [AccTup], [AccTup])
groupCommonAccs [AccTup]
tup_accs1 ([AccTup] -> ([(AccTup, AccTup)], [AccTup], [AccTup]))
-> [AccTup] -> ([(AccTup, AccTup)], [AccTup], [AccTup])
forall a b. (a -> b) -> a -> b
$
if [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
commons2
then [AccTup]
tup_accs2
else (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Bool)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
forall a. (a -> Bool) -> [a] -> [a]
filter (Bool -> Bool
not (Bool -> Bool)
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Bool)
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. AccTup -> AccTup -> Bool
matchingAccTup AccTup
tup_acc1) [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
tup_accs2
in if [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
commons2
then ([(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
[(AccTup, AccTup)]
rec1, ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
AccTup
tup_acc1 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
forall a. a -> [a] -> [a]
: [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
rec2, [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
rec3)
else ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
AccTup
tup_acc1, [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))
forall a. HasCallStack => [a] -> a
head [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
commons2) (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
forall a. a -> [a] -> [a]
: [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))]
rec1, [AccTup]
tup_accs1, [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))]
[AccTup]
rec3)
groupCommonAccs [AccTup]
_ [AccTup]
_ =
String
-> ([(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)),
([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs)))],
[([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))],
[([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
(VName, Certs))])
forall a. HasCallStack => String -> a
error String
"Unreachable case reached in groupCommonAccs!"
renameLamInWAccInp :: (a, b, Maybe (Lambda rep, b)) -> m (a, b, Maybe (Lambda rep, b))
renameLamInWAccInp (a
shp, b
inps, Just (Lambda rep
lam, b
se)) = do
Lambda rep
lam' <- Lambda rep -> m (Lambda rep)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda rep
lam
(a, b, Maybe (Lambda rep, b)) -> m (a, b, Maybe (Lambda rep, b))
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (a
shp, b
inps, (Lambda rep, b) -> Maybe (Lambda rep, b)
forall a. a -> Maybe a
Just (Lambda rep
lam', b
se))
renameLamInWAccInp (a, b, Maybe (Lambda rep, b))
winp = (a, b, Maybe (Lambda rep, b)) -> m (a, b, Maybe (Lambda rep, b))
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (a, b, Maybe (Lambda rep, b))
winp
tryFuseWithAccs [VName]
_ Stm SOACS
_ Stm SOACS
_ =
Maybe (Stm SOACS) -> m (Maybe (Stm SOACS))
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (Stm SOACS)
forall a. Maybe a
Nothing
-- | Check whether two lambdas are structurally equivalent up to a
-- consistent renaming of their bound variables.
--
-- The table @stab@ maps names occurring in @lam2@ to the corresponding
-- names in @lam1@. The lambdas are considered equivalent when all of the
-- following hold:
--
--   * their parameters have identical types and attributes;
--   * their return types and body decorations coincide;
--   * their bodies contain the same number of statements, and the
--     statements match pairwise according to 'equivStm'. Each successful
--     match extends the renaming table with the names bound by that pair;
--   * the result sub-expressions of @lam2@, renamed through the final
--     table, equal those of @lam1@. Result certificates are not compared.
--
-- The check is conservative: 'False' only means that equivalence could
-- not be established.
equivLambda ::
  M.Map VName VName ->
  Lambda SOACS ->
  Lambda SOACS ->
  Bool
equivLambda stab lam1 lam2
  | (ps1, ps2) <- (lambdaParams lam1, lambdaParams lam2),
    map paramDec ps1 == map paramDec ps2,
    map paramAttrs ps1 == map paramAttrs ps2,
    lambdaReturnType lam1 == lambdaReturnType lam2,
    (bdy1, bdy2) <- (lambdaBody lam1, lambdaBody lam2),
    bodyDec bdy1 == bodyDec bdy2,
    (stms1, stms2) <- (stmsToList (bodyStms bdy1), stmsToList (bodyStms bdy2)),
    -- 'zip' below would silently drop the unmatched trailing statements
    -- of the longer body, so insist on equally many statements up front.
    length stms1 == length stms2 =
      let -- Rename the parameters of lam2 to those of lam1. The new
          -- entries take priority over existing ones (left-biased union).
          stab' =
            M.union (M.fromList (zip (map paramName ps2) (map paramName ps1))) stab
          -- Match one pair of statements, extending the renaming table;
          -- 'Nothing' aborts the whole comparison.
          matchStm vtab (s1, s2) =
            case equivStm vtab s1 s2 of
              (vtab', True) -> Just vtab'
              _ -> Nothing
       in case foldM matchStm stab' (zip stms1 stms2) of
            Nothing -> False
            Just stab'' ->
              map resSubExp (bodyResult bdy1)
                == substInSEs stab'' (map resSubExp (bodyResult bdy2))
equivLambda _ _ _ = False
-- | Try to match two statements as equivalent under the renaming table
-- @stab@. The table maps names from the second statement's scope to
-- names from the first's.
--
-- Only scalar binary operations are recognised. Both statements must:
--
--   * apply the same 'BinOp';
--   * carry identical auxiliary information;
--   * have operands such that, after renaming and in the same order, the
--     second's equal the first's;
--   * bind pattern elements of identical types.
--
-- On success, the names bound by the second statement are mapped to those
-- bound by the first. The extended table is returned paired with 'True'.
-- Otherwise the table is returned unchanged, paired with 'False'.
equivStm ::
  M.Map VName VName ->
  Stm SOACS ->
  Stm SOACS ->
  (M.Map VName VName, Bool)
equivStm stab stm1 stm2 =
  maybe (stab, False) (\tab -> (tab, True)) (matchBinOps stm1 stm2)
  where
    matchBinOps :: Stm SOACS -> Stm SOACS -> Maybe (M.Map VName VName)
    matchBinOps
      (Let pat1 aux1 (BasicOp (BinOp op1 lhs1 rhs1)))
      (Let pat2 aux2 (BasicOp (BinOp op2 lhs2 rhs2)))
        | [lhs1, rhs1] == substInSEs stab [lhs2, rhs2],
          map patElemDec elems1 == map patElemDec elems2,
          op1 == op2,
          aux1 == aux2 =
            -- New bindings of the second statement shadow older entries.
            Just $
              M.union
                (M.fromList (zip (map patElemName elems2) (map patElemName elems1)))
                stab
        where
          elems1 = patElems pat1
          elems2 = patElems pat2
    matchBinOps _ _ = Nothing
-- | Decide whether the accumulator described by the second tuple directly
-- consumes the accumulator described by the first. This requires all of
-- the following:
--
--   * both accumulator inputs have the same shape dimensions;
--   * the arrays fed to the second accumulator are exactly the names
--     bound by the first tuple's pattern elements, in order;
--   * both accumulators either have no operator, or have operators whose
--     accompanying sub-expression lists are equal and whose lambdas are
--     equivalent according to 'equivLambda' (checked with an empty
--     renaming table).
--
-- NOTE(review): the sub-expression list paired with the operator lambda
-- is presumably the neutral element(s); confirm against 'WithAccInput'.
matchingAccTup :: AccTup -> AccTup -> Bool
matchingAccTup (pes1, (shape1, _, mop1), _, _, _) (_, (shape2, arrs2, mop2), _, _, _) =
  and
    [ shapeDims shape1 == shapeDims shape2,
      map patElemName pes1 == arrs2,
      sameOperator mop1 mop2
    ]
  where
    sameOperator Nothing Nothing = True
    sameOperator (Just (op1, ses1)) (Just (op2, ses2)) =
      ses1 == ses2 && equivLambda M.empty op1 op2
    sameOperator _ _ = False
-- | Rename the variables in a list of sub-expressions through @vtab@.
-- A variable that has an entry in the table is replaced by the mapped
-- name. Unmapped variables and constants are returned unchanged.
substInSEs :: M.Map VName VName -> [SubExp] -> [SubExp]
substInSEs vtab ses = [substOne se | se <- ses]
  where
    substOne se@(Var v) = maybe se Var (M.lookup v vtab)
    substOne se = se