{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels.Interchange
( SeqLoop (..),
interchangeLoops,
Branch (..),
interchangeBranch,
WithAccStm (..),
interchangeWithAcc,
)
where
import Control.Monad
import Data.List (find)
import Data.Maybe
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.Distribution
( KernelNest,
LoopNesting (..),
kernelNestLoops,
scopeOfKernelNest,
)
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (splitFromEnd)
-- | A sequential loop bundled with the pattern its result is bound to.
--
-- The constructor fields are, in order:
--
-- 1. a permutation, applied with @rearrangeShape@ to pattern
--    elements/names by the functions in this module;
--
-- 2. the pattern of the loop statement;
--
-- 3. the loop (merge) parameters paired with their initial values;
--
-- 4. the loop form;
--
-- 5. the loop body.
data SeqLoop
  = SeqLoop
      [Int]
      (Pat Type)
      [(FParam SOACS, SubExp)]
      LoopForm
      (Body SOACS)
-- | The permutation stored as the first field of a 'SeqLoop'.
loopPerm :: SeqLoop -> [Int]
loopPerm (SeqLoop perm _ _ _ _) = perm
-- | Convert a 'SeqLoop' into a @Loop@ statement bound to the loop's
-- pattern, using default auxiliary information.  The permutation
-- field is not used.
seqLoopStm :: SeqLoop -> Stm SOACS
seqLoopStm (SeqLoop _ pat merge form body) =
  Let pat (defAux ()) $ Loop merge form body
-- | Interchange a sequential loop with a single enclosing map
-- nesting, producing a new sequential loop whose body contains the
-- map.
--
-- The first argument maps the name of a parameter of the map nesting
-- to the array that parameter is drawn from (or 'Nothing' if the name
-- is not such a parameter).
--
-- Every merge parameter is expanded with an outer dimension of size
-- @w@ (the width of the nesting), and so is every element of the
-- loop's pattern.  The resulting loop keeps the permutation and loop
-- form of the original and is bound to the (permuted) pattern of the
-- map nesting.
interchangeLoop ::
  (MonadBuilder m, Rep m ~ SOACS) =>
  (VName -> Maybe VName) ->
  SeqLoop ->
  LoopNesting ->
  m SeqLoop
interchangeLoop
  isMapParameter
  (SeqLoop perm loop_pat merge form body)
  (MapNesting pat aux w params_and_arrs) = do
    -- The initial values may refer to the map parameters, so those
    -- must be in scope while the expanded initial values are built.
    merge_expanded <-
      localScope (scopeOfLParams $ map fst params_and_arrs) $
        mapM expand merge

    let loop_pat_expanded =
          Pat $ map expandPatElem $ patElems loop_pat
        -- The original merge parameters become (non-unique) lambda
        -- parameters of the new inner map...
        new_params =
          [ Param attrs pname $ fromDecl ptype
            | (Param attrs pname ptype, _) <- merge
          ]
        -- ...which maps over the expanded merge parameters.
        new_arrs = map (paramName . fst) merge_expanded
        rettype = map rowType $ patTypes loop_pat_expanded

    -- Keep only those map parameters that the loop body actually
    -- references.  Any statements produced while doing so are placed
    -- in front of the map inside the new loop body.
    ((params', arrs'), pre_copy_stms) <-
      runBuilder $
        localScope (scopeOfLParams new_params) $
          unzip . catMaybes <$> mapM copyOrRemoveParam params_and_arrs

    let lam = Lambda (params' <> new_params) rettype body
        map_stm =
          Let loop_pat_expanded aux $
            Op $
              Screma w (arrs' <> new_arrs) (mapSOAC lam)
        res = varsRes $ patNames loop_pat_expanded
        pat' = Pat $ rearrangeShape perm $ patElems pat

    pure $
      SeqLoop perm pat' merge_expanded form $
        mkBody (pre_copy_stms <> oneStm map_stm) res
    where
      free_in_body = freeIn body

      -- Drop a map parameter that is not free in the loop body;
      -- otherwise keep it (and its array) unchanged.
      copyOrRemoveParam (param, arr)
        | paramName param `notNameIn` free_in_body =
            pure Nothing
        | otherwise =
            pure $ Just (param, arr)

      -- If the initial value is itself a map parameter, the expanded
      -- initial value is simply the array being mapped over.
      -- Otherwise the value is replicated @w@ times.
      expandedInit _ (Var v)
        | Just arr <- isMapParameter v =
            pure $ Var arr
      expandedInit param_name se =
        letSubExp (param_name <> "_expanded_init") $
          BasicOp $
            Replicate (Shape [w]) se

      -- Create the expanded version of a merge parameter along with
      -- its expanded initial value.
      expand (merge_param, merge_init) = do
        expanded_param <-
          newParam (param_name <> "_expanded") $
            -- NOTE(review): the expanded parameter is always declared
            -- Unique, whatever the uniqueness of the original
            -- parameter - confirm this is intended.
            arrayOf (paramDeclType merge_param) (Shape [w]) Unique
        expanded_init <- expandedInit param_name merge_init
        pure (expanded_param, expanded_init)
        where
          param_name = baseString $ paramName merge_param

      -- Add an outer dimension of size @w@ to a pattern element.
      expandPatElem (PatElem name t) =
        PatElem name $ arrayOfRow t w
-- | Rewrite the initial values of the merge parameters of a loop.
--
-- When a merge parameter has array type and its initial value is a
-- variable for which the given predicate holds, the initial value is
-- replaced by a freshly bound @Replicate@ of that variable with an
-- empty shape (presumably acting as a copy, as the @_inter_copy@ name
-- suffix suggests).  All other merge parameters, and the rest of the
-- loop, are returned unchanged.
maybeCopyInitial ::
  (MonadBuilder m) =>
  (VName -> Bool) ->
  SeqLoop ->
  m SeqLoop
maybeCopyInitial isMapInput (SeqLoop perm loop_pat merge form body) =
  SeqLoop perm loop_pat <$> mapM f merge <*> pure form <*> pure body
  where
    f (p, Var arg)
      | isMapInput arg,
        Array {} <- paramType p =
          (p,)
            <$> letSubExp
              (baseString (paramName p) <> "_inter_copy")
              (BasicOp $ Replicate mempty $ Var arg)
    f (p, arg) =
      pure (p, arg)
-- | Wrap the given statements (with the given result names) in one
-- @map@ per loop nesting, with the first nesting in the list becoming
-- the outermost map.
--
-- For an empty list of nestings the result names and statements are
-- returned unchanged.  Otherwise the result is the pattern names of
-- the first nesting, together with a single statement binding that
-- pattern to a @Screma@ map whose lambda body contains everything
-- built for the remaining nestings.
manifestMaps :: [LoopNesting] -> [VName] -> Stms SOACS -> ([VName], Stms SOACS)
manifestMaps [] res stms = (res, stms)
manifestMaps (n : ns) res stms =
  let -- Build the inner nestings first; they become this map's body.
      (res', stms') = manifestMaps ns res stms
      (params, arrs) = unzip $ loopNestingParamsAndArrs n
      lam =
        Lambda
          params
          (map rowType $ patTypes (loopNestingPat n))
          (mkBody stms' $ varsRes res')
   in ( patNames $ loopNestingPat n,
        oneStm $
          Let (loopNestingPat n) (loopNestingAux n) $
            Op $
              Screma (loopNestingWidth n) arrs (mapSOAC lam)
      )
-- | Given a map nest and a sequential loop inside it, move the maps
-- inside the sequential loop.  The result is a sequence of statements
-- containing the loop.
--
-- The nestings are processed from the innermost (last) outwards.  For
-- each nesting the loop is interchanged with it via 'interchangeLoop'
-- and then passed through 'maybeCopyInitial'.
interchangeLoops ::
  (MonadFreshNames m, HasScope SOACS m) =>
  KernelNest ->
  SeqLoop ->
  m (Stms SOACS)
interchangeLoops full_nest = recurse (kernelNestLoops full_nest)
  where
    recurse nest loop
      -- Peel off the last nesting, if there is one.
      | (ns, [n]) <- splitFromEnd 1 nest = do
          let -- The array that a parameter of nesting @n@ is drawn from.
              isMapParameter v =
                snd <$> find ((== v) . paramName . fst) (loopNestingParamsAndArrs n)
              -- Whether a name is one of the arrays mapped over by @n@.
              isMapInput v =
                v `elem` map snd (loopNestingParamsAndArrs n)
          (loop', stms) <-
            runBuilder . localScope (scopeOfKernelNest full_nest) $
              maybeCopyInitial isMapInput
                =<< interchangeLoop isMapParameter loop n

          -- If the interchange did not produce any extra statements,
          -- carry on with the next nesting outwards.  Otherwise stop
          -- here: emit the extra statements followed by the loop, and
          -- wrap them in maps for the remaining nestings.
          if null stms
            then recurse ns loop'
            else
              let loop_stm = seqLoopStm loop'
                  names = rearrangeShape (loopPerm loop') (patNames (stmPat loop_stm))
               in pure $ snd $ manifestMaps ns names $ stms <> oneStm loop_stm
      -- No nestings left: the result is just the loop statement.
      | otherwise = pure $ oneStm $ seqLoopStm loop
-- | A @Match@ expression bundled with the pattern its result is bound
-- to.
--
-- The constructor fields are, in order:
--
-- 1. a permutation, applied with @rearrangeShape@ by the functions in
--    this module;
--
-- 2. the pattern of the statement;
--
-- 3. the subexpressions being matched on;
--
-- 4. the cases;
--
-- 5. the default body;
--
-- 6. the match decoration.
data Branch
  = Branch
      [Int]
      (Pat Type)
      [SubExp]
      [Case (Body SOACS)]
      (Body SOACS)
      (MatchDec (BranchType SOACS))
-- | Convert a 'Branch' into a @Match@ statement bound to the branch's
-- pattern, using default auxiliary information.  The permutation
-- field is not used.
branchStm :: Branch -> Stm SOACS
branchStm (Branch _ pat cond cases defbody ret) =
  Let pat (defAux ()) $ Match cond cases defbody ret
interchangeBranch1 ::
(MonadFreshNames m, HasScope SOACS m) =>
Branch ->
LoopNesting ->
m Branch
interchangeBranch1 :: forall (m :: * -> *).
(MonadFreshNames m, HasScope SOACS m) =>
Branch -> LoopNesting -> m Branch
interchangeBranch1
(Branch [Int]
perm Pat Type
branch_pat [SubExp]
cond [Case (Body SOACS)]
cases Body SOACS
defbody (MatchDec [BranchType SOACS]
ret MatchSort
if_sort))
(MapNesting Pat Type
pat StmAux ()
aux SubExp
w [(Param Type, VName)]
params_and_arrs) = do
let ret' :: [ExtType]
ret' = (ExtType -> ExtType) -> [ExtType] -> [ExtType]
forall a b. (a -> b) -> [a] -> [b]
map (ExtType -> Ext SubExp -> ExtType
forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
`arrayOfRow` SubExp -> Ext SubExp
forall a. a -> Ext a
Free SubExp
w) [ExtType]
[BranchType SOACS]
ret
pat' :: Pat Type
pat' = [PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat ([PatElem Type] -> Pat Type) -> [PatElem Type] -> Pat Type
forall a b. (a -> b) -> a -> b
$ [Int] -> [PatElem Type] -> [PatElem Type]
forall a. [Int] -> [a] -> [a]
rearrangeShape [Int]
perm ([PatElem Type] -> [PatElem Type])
-> [PatElem Type] -> [PatElem Type]
forall a b. (a -> b) -> a -> b
$ Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
pat
([Param Type]
params, [VName]
arrs) = [(Param Type, VName)] -> ([Param Type], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param Type, VName)]
params_and_arrs
lam_ret :: [Type]
lam_ret = [Int] -> [Type] -> [Type]
forall a. [Int] -> [a] -> [a]
rearrangeShape [Int]
perm ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Type
forall u.
TypeBase (ShapeBase SubExp) u -> TypeBase (ShapeBase SubExp) u
rowType ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Pat Type -> [Type]
forall dec. Typed dec => Pat dec -> [Type]
patTypes Pat Type
pat
branch_pat' :: Pat Type
branch_pat' =
[PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat ([PatElem Type] -> Pat Type) -> [PatElem Type] -> Pat Type
forall a b. (a -> b) -> a -> b
$ (PatElem Type -> PatElem Type) -> [PatElem Type] -> [PatElem Type]
forall a b. (a -> b) -> [a] -> [b]
map ((Type -> Type) -> PatElem Type -> PatElem Type
forall a b. (a -> b) -> PatElem a -> PatElem b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Type -> SubExp -> Type
forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
`arrayOfRow` SubExp
w)) ([PatElem Type] -> [PatElem Type])
-> [PatElem Type] -> [PatElem Type]
forall a b. (a -> b) -> a -> b
$ Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
branch_pat
mkBranch :: Body SOACS -> m (Body SOACS)
mkBranch Body SOACS
branch = (Body SOACS -> m (Body SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Body rep -> m (Body rep)
renameBody =<<) (m (Body SOACS) -> m (Body SOACS))
-> m (Body SOACS) -> m (Body SOACS)
forall a b. (a -> b) -> a -> b
$ Builder SOACS Result -> m (Body SOACS)
forall rep (m :: * -> *) somerep.
(Buildable rep, MonadFreshNames m, HasScope somerep m,
SameScope somerep rep) =>
Builder rep Result -> m (Body rep)
runBodyBuilder (Builder SOACS Result -> m (Body SOACS))
-> Builder SOACS Result -> m (Body SOACS)
forall a b. (a -> b) -> a -> b
$ do
let lam :: Lambda SOACS
lam = [LParam SOACS] -> [Type] -> Body SOACS -> Lambda SOACS
forall rep. [LParam rep] -> [Type] -> Body rep -> Lambda rep
Lambda [Param Type]
[LParam SOACS]
params [Type]
lam_ret Body SOACS
branch
Stm (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ())
-> Stm (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
-> StmAux (ExpDec (Rep (BuilderT SOACS (State VNameSource))))
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> Stm (Rep (BuilderT SOACS (State VNameSource)))
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
branch_pat' StmAux ()
StmAux (ExpDec (Rep (BuilderT SOACS (State VNameSource))))
aux (Exp (Rep (BuilderT SOACS (State VNameSource)))
-> Stm (Rep (BuilderT SOACS (State VNameSource))))
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> Stm (Rep (BuilderT SOACS (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ Op (Rep (BuilderT SOACS (State VNameSource)))
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
forall rep. Op rep -> Exp rep
Op (Op (Rep (BuilderT SOACS (State VNameSource)))
-> Exp (Rep (BuilderT SOACS (State VNameSource))))
-> Op (Rep (BuilderT SOACS (State VNameSource)))
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName]
arrs (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam
Result -> Builder SOACS Result
forall a. a -> BuilderT SOACS (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> Builder SOACS Result) -> Result -> Builder SOACS Result
forall a b. (a -> b) -> a -> b
$ [VName] -> Result
varsRes ([VName] -> Result) -> [VName] -> Result
forall a b. (a -> b) -> a -> b
$ Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
branch_pat'
[Case (Body SOACS)]
cases' <- (Case (Body SOACS) -> m (Case (Body SOACS)))
-> [Case (Body SOACS)] -> m [Case (Body SOACS)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ((Body SOACS -> m (Body SOACS))
-> Case (Body SOACS) -> m (Case (Body SOACS))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Case a -> f (Case b)
traverse Body SOACS -> m (Body SOACS)
mkBranch) [Case (Body SOACS)]
cases
Body SOACS
defbody' <- Body SOACS -> m (Body SOACS)
mkBranch Body SOACS
defbody
Branch -> m Branch
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Branch -> m Branch)
-> (MatchDec ExtType -> Branch) -> MatchDec ExtType -> m Branch
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Int]
-> Pat Type
-> [SubExp]
-> [Case (Body SOACS)]
-> Body SOACS
-> MatchDec (BranchType SOACS)
-> Branch
Branch [Int
0 .. Pat Type -> Int
forall dec. Pat dec -> Int
patSize Pat Type
pat Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1] Pat Type
pat' [SubExp]
cond [Case (Body SOACS)]
cases' Body SOACS
defbody' (MatchDec ExtType -> m Branch) -> MatchDec ExtType -> m Branch
forall a b. (a -> b) -> a -> b
$
[ExtType] -> MatchSort -> MatchDec ExtType
forall rt. [rt] -> MatchSort -> MatchDec rt
MatchDec [ExtType]
ret' MatchSort
if_sort
-- | Interchange a 'Branch' with every loop of a kernel nest and return
-- the resulting statement.
--
-- The loops of the nest (as given by 'kernelNestLoops') are visited in
-- reverse order; each step applies 'interchangeBranch1' to the branch
-- produced by the previous step.  The final 'Branch' is converted into
-- a statement with 'branchStm'.
interchangeBranch ::
  (MonadFreshNames m, HasScope SOACS m) =>
  KernelNest ->
  Branch ->
  m (Stm SOACS)
interchangeBranch nest branch =
  branchStm <$> foldM interchangeBranch1 branch (reverse $ kernelNestLoops nest)
-- | A @WithAcc@ expression bundled with the pattern it is bound to.
--
-- 'withAccStm' turns a value of this type back into a statement.
data WithAccStm
  = WithAccStm
      -- | Permutation; applied with @rearrangeShape@ to the pattern
      -- elements of the enclosing map when interchanging.
      [Int]
      -- | Pattern the @WithAcc@ result is bound to.
      (Pat Type)
      -- | Inputs of the @WithAcc@: shape, arrays, and an optional
      -- operator lambda together with its neutral elements.
      [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))]
      -- | The lambda of the @WithAcc@.
      (Lambda SOACS)
-- | Turn a 'WithAccStm' into a statement binding its pattern to a
-- @WithAcc@ expression built from its inputs and lambda.
--
-- The permutation component is ignored, and the statement gets default
-- auxiliary information (@defAux ()@).
withAccStm :: WithAccStm -> Stm SOACS
withAccStm (WithAccStm _ pat inputs lam) =
  Let pat (defAux ()) $ WithAcc inputs lam
-- | Interchange a 'WithAccStm' with a single map nesting.
--
-- The result is a new 'WithAccStm' in which:
--
-- * every input has the map width @w@ prepended to its shape, every
--   input array that names a parameter of the map nesting is replaced
--   by the array that parameter is mapped over, and every operator
--   lambda gains a leading @int64@ index parameter;
--
-- * the lambda body is a single @Screma@ map of width @w@ whose lambda
--   has the body of the original accumulator lambda, and which maps
--   over a fresh iota, the (renamed) accumulator parameters and the
--   arrays of the map nesting;
--
-- * accumulator types whose certificate belongs to this @WithAcc@ get
--   @w@ prepended to their index space, and every @UpdateAcc@ on such
--   an accumulator gets the iota element prepended to its indices;
--
-- * the pattern is the pattern of the map nesting, rearranged by the
--   permutation carried in the 'WithAccStm'.
interchangeWithAcc1 ::
  (MonadFreshNames m, LocalScope SOACS m) =>
  WithAccStm ->
  LoopNesting ->
  m WithAccStm
interchangeWithAcc1
  (WithAccStm perm _withacc_pat inputs acc_lam)
  (MapNesting map_pat map_aux w params_and_arrs) = do
    inputs' <- mapM onInput inputs
    lam_params' <- newAccLamParams $ lambdaParams acc_lam
    -- Parameter of the new map lambda that receives the iota element;
    -- it is used as the extra leading index in rewritten UpdateAccs.
    iota_p <- newParam "iota_p" $ Prim int64
    acc_lam' <- trLam (Var (paramName iota_p)) <=< runLambdaBuilder lam_params' $ do
      let acc_params = drop num_accs lam_params'
          orig_acc_params = drop num_accs $ lambdaParams acc_lam
      iota_w <-
        letExp "acc_inter_iota" . BasicOp $
          Iota w (intConst Int64 0) (intConst Int64 1) Int64
      let (params, arrs) = unzip params_and_arrs
          maplam_ret = lambdaReturnType acc_lam
          -- The original accumulator parameter names are reused as map
          -- lambda parameters, so the original body can be used as-is;
          -- the freshly named 'acc_params' are what gets mapped over.
          maplam =
            Lambda
              (iota_p : orig_acc_params ++ params)
              maplam_ret
              (lambdaBody acc_lam)
      auxing map_aux . fmap subExpsRes . letTupExp' "withacc_inter" $
        Op $
          Screma w (iota_w : map paramName acc_params ++ arrs) (mapSOAC maplam)
    let pat = Pat $ rearrangeShape perm $ patElems map_pat
    pure $ WithAccStm perm pat inputs' acc_lam'
  where
    -- Give fresh names to the second half of the lambda parameters (the
    -- accumulators) while keeping the first half (the certificates)
    -- untouched.
    newAccLamParams ps = do
      let (cert_ps, acc_ps) = splitAt (length ps `div` 2) ps
      acc_ps' <- forM acc_ps $ \(Param attrs v t) ->
        Param attrs <$> newVName (baseString v) <*> pure t
      pure $ cert_ps <> acc_ps'

    num_accs :: Int
    num_accs = length inputs

    -- Names of the first 'num_accs' parameters of the accumulator
    -- lambda; used to recognise accumulators belonging to this WithAcc.
    acc_certs :: [VName]
    acc_certs = map paramName $ take num_accs $ lambdaParams acc_lam

    -- If the array is a parameter of the map nesting, use the array
    -- that parameter is mapped over instead; otherwise keep it.
    onArr v =
      pure . maybe v snd $
        find ((== v) . paramName . fst) params_and_arrs

    onInput (shape, arrs, op) =
      (Shape [w] <> shape,,) <$> mapM onArr arrs <*> traverse onOp op

    -- The index space of the accumulator grows by one dimension (see
    -- 'onInput'), so the operator lambda gets one more index parameter.
    onOp (op_lam, nes) = do
      idx_p <- newParam "idx" $ Prim int64
      pure (op_lam {lambdaParams = idx_p : lambdaParams op_lam}, nes)

    -- Prepend @w@ to the index space of accumulator types whose
    -- certificate is one of 'acc_certs'; leave all other types alone.
    trType :: TypeBase shape u -> TypeBase shape u
    trType (Acc acc ispace ts u)
      | acc `elem` acc_certs =
          Acc acc (Shape [w] <> ispace) ts u
    trType t = t

    trParam :: Param (TypeBase shape u) -> Param (TypeBase shape u)
    trParam = fmap trType

    -- The tr* functions below traverse a lambda/body/statement/
    -- expression, applying 'trType' to the types they carry and
    -- prepending index @i@ to relevant UpdateAccs.  Scope is extended
    -- along the way so that 'lookupType' in 'trExp' can see local
    -- bindings.
    trLam i (Lambda params ret body) =
      localScope (scopeOfLParams params) $
        Lambda (map trParam params) (map trType ret) <$> trBody i body

    trBody i (Body dec stms res) =
      inScopeOf stms $ Body dec <$> traverse (trStm i) stms <*> pure res

    trStm i (Let pat aux e) =
      Let (fmap trType pat) aux <$> trExp i e

    trSOAC i = mapSOACM mapper
      where
        mapper =
          identitySOACMapper {mapOnSOACLambda = trLam i}

    trExp i (WithAcc acc_inputs lam) =
      WithAcc acc_inputs <$> trLam i lam
    trExp i (BasicOp (UpdateAcc safety acc is ses)) = do
      acc_t <- lookupType acc
      pure $ case acc_t of
        Acc cert _ _ _
          | cert `elem` acc_certs ->
              -- One of our accumulators: add the extra leading index.
              BasicOp $ UpdateAcc safety acc (i : is) ses
        _ ->
          BasicOp $ UpdateAcc safety acc is ses
    trExp i e = mapExpM mapper e
      where
        mapper =
          identityMapper
            { mapOnBody = \scope -> localScope scope . trBody i,
              mapOnRetType = pure . trType,
              mapOnBranchType = pure . trType,
              mapOnFParam = pure . trParam,
              mapOnLParam = pure . trParam,
              mapOnOp = trSOAC i
            }
-- | Interchange a 'WithAccStm' with every loop of a kernel nest and
-- return the resulting statement.
--
-- The loops of the nest (as given by 'kernelNestLoops') are visited in
-- reverse order; each step applies 'interchangeWithAcc1' to the
-- 'WithAccStm' produced by the previous step.  The final value is
-- converted into a statement with 'withAccStm'.
interchangeWithAcc ::
  (MonadFreshNames m, LocalScope SOACS m) =>
  KernelNest ->
  WithAccStm ->
  m (Stm SOACS)
interchangeWithAcc nest withacc =
  withAccStm <$> foldM interchangeWithAcc1 withacc (reverse $ kernelNestLoops nest)