{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.GenRedOpt (optimiseGenRed) where
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Builder
import Futhark.IR.GPU
import Futhark.Optimise.TileLoops.Shared
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename
-- | Monad used throughout this pass: a reader over the current 'Scope GPU'
-- (extended with 'localScope' when entering nested code) layered on a
-- 'State VNameSource' used to generate fresh names.
type GenRedM = ReaderT (Scope GPU) (State VNameSource)
-- | The pass definition: applies the generalized-reduction rewrites to the
-- statements of every function, one function at a time.
optimiseGenRed :: Pass GPU GPU
optimiseGenRed =
  Pass
    "optimise generalized reductions"
    "Specializes generalized reductions into map-reductions or histograms"
    $ intraproceduralTransformation onStms
  where
    -- Run the statement optimiser in 'GenRedM' with an empty environment,
    -- threading the caller's name source through the 'State' layer.
    onStms :: (MonadFreshNames m) => Scope GPU -> Stms GPU -> m (Stms GPU)
    onStms scope stms =
      modifyNameSource . runState $
        runReaderT (optimiseStms (M.empty, M.empty) stms) scope
-- | Optimise the statements of a body; the body's result is kept unchanged.
optimiseBody :: Env -> Body GPU -> GenRedM (Body GPU)
optimiseBody env (Body () stms res) = do
  stms' <- optimiseStms env stms
  pure $ Body () stms' res
-- | Optimise a sequence of statements in order.
--
-- The statements' own bindings are added to the scope first. The 'Env'
-- returned by each statement is passed on to the next one. The final
-- 'Env' is discarded, and the rewritten statements are concatenated in
-- their original order.
optimiseStms :: Env -> Stms GPU -> GenRedM (Stms GPU)
optimiseStms env stms =
  localScope (scopeOf stms) $
    snd <$> foldM step (env, mempty) (stmsToList stms)
  where
    step :: (Env, Stms GPU) -> Stm GPU -> GenRedM (Env, Stms GPU)
    step (cur_env, done) stm = do
      (next_env, new_stms) <- optimiseStm cur_env stm
      pure (next_env, done <> new_stms)
-- | Optimise a single statement.
--
-- * A thread-level 'SegMap' kernel is handed to 'genRedOpts'. If that
--   rewrites it, the kernel is replaced by the resulting statements;
--   otherwise it is kept as is. The environment is not changed.
-- * Any other statement first updates the environment via 'changeEnv',
--   keyed by its first bound name, and then has its nested bodies optimised
--   recursively with the updated environment.
optimiseStm :: Env -> Stm GPU -> GenRedM (Env, Stms GPU)
optimiseStm env stm@(Let _ _ (Op (SegOp (SegMap SegThread {} _ _ _)))) = do
  res_genred_opt <- genRedOpts env stm
  pure (env, fromMaybe (oneStm stm) res_genred_opt)
optimiseStm env (Let pat aux e) = do
  -- The original code used 'head', which crashes on a statement with an
  -- empty pattern. With no bound name there is nothing to key the
  -- environment on, so the environment is simply kept.
  env' <- case patNames pat of
    v : _ -> changeEnv env v e
    [] -> pure env
  e' <- mapExpM (optimise env') e
  pure (env', oneStm $ Let pat aux e')
  where
    -- Recurse into every nested body, extending the scope with the
    -- bindings the mapper supplies for that body.
    optimise env' =
      identityMapper {mapOnBody = \scope -> localScope scope . optimiseBody env'}
-- | Try to rewrite a kernel statement.
--
-- 'genRed2Tile2d' is tried first, and 'genRed2SegRed' only if that fails.
-- When a rewrite produces @(stms_before, ker_snd)@, the residual kernel
-- @ker_snd@ is itself fed back into 'genRedOpts'. Its rewrite, or
-- @ker_snd@ unchanged if none applies, is appended after @stms_before@.
-- Returns 'Nothing' when no rewrite applies to @ker@.
genRedOpts :: Env -> Stm GPU -> GenRedM (Maybe (Stms GPU))
genRedOpts env ker = do
  res_tile <- genRed2Tile2d env ker
  res <- case res_tile of
    Nothing -> genRed2SegRed env ker
    Just _ -> pure res_tile
  continueWith res
  where
    continueWith :: Maybe (Stms GPU, Stm GPU) -> GenRedM (Maybe (Stms GPU))
    continueWith Nothing = pure Nothing
    continueWith (Just (stms_before, ker_snd)) = do
      mb_stms_after <- genRedOpts env ker_snd
      pure . Just $ stms_before <> fromMaybe (oneStm ker_snd) mb_stms_after
-- | Try to split a thread-level generalized-reduction 'SegMap' kernel at its
-- accumulator update ('UpdateAcc').
--
-- On success returns @Just (pre, ker2)@ where:
--
--   * @pre@ is the host statements produced by 'transposeFVs', followed by
--     a new kernel @ker1@. @ker1@ performs the accumulation as a sequential
--     map-reduce ('Screma') over the parallel dimension the accumulator
--     indices are invariant to, and runs in parallel over the remaining
--     dimensions.
--   * @ker2@ is the original kernel with the 'UpdateAcc' statement removed;
--     its body is @code1 <> code2@.
--
-- Returns 'Nothing' if any of the guard preconditions fails or the
-- statement is not such a kernel.
genRed2Tile2d :: Env -> Stm GPU -> GenRedM (Maybe (Stms GPU, Stm GPU))
genRed2Tile2d env kerstm@(Let pat_ker aux (Op (SegOp (SegMap seg_thd seg_space kres_tps old_kbody))))
  | SegThread _novirt _ <- seg_thd,
    KernelBody () kstms kres <- old_kbody,
    Just (css, r_ses) <- allGoodReturns kres,
    null css,
    -- Variance table: for each name, the names it depends on. Seeded with
    -- the segment-space indices, each mapped to an empty set.
    initial_variance <- M.map mempty $ scopeOfSegSpace seg_space,
    variance <- varianceInStms initial_variance kstms,
    -- The body must split into a prefix (code1), one accumulator
    -- statement, and a suffix (code2). The accumulator statement must be
    -- an 'UpdateAcc' binding exactly one name.
    (code1, Just accum_stmt, code2) <- matchCodeAccumCode kstms,
    Let pat_accum _aux_acc (BasicOp (UpdateAcc safety acc_nm acc_inds acc_vals)) <- accum_stmt,
    [pat_acc_nm] <- patNames pat_accum,
    -- A parallel dimension (its gid and position) to which the
    -- accumulator indices are invariant.
    Just (invar_gid, gid_ind) <- isInvarToParDim mempty seg_space variance acc_inds,
    -- The remaining parallel dimensions. Those the accumulator indices vary
    -- with are moved to the end, in the order of the matching indices.
    gid_dims_new_0 <- filter ((invar_gid /=) . fst) (unSegSpace seg_space),
    gid_dims_new <- reorderParDims variance acc_inds gid_dims_new_0,
    -- Every indexing statement in code1 that the accumulator depends on
    -- must be invariant to the sequentialised gid or to at least one of
    -- the remaining parallel gids.
    all (isTileable invar_gid gid_dims_new variance pat_acc_nm) (stmsToList code1),
    -- Reject when the estimated cost of the code executed redundantly by
    -- both kernels exceeds 'Small 2'. This assumes 'maxCost' returns the
    -- larger of its arguments.
    cost <- costRedundantExecution variance pat_acc_nm r_ses kstms,
    maxCost cost (Small 2) == Small 2 = do
      -- 1. Build the first kernel, which performs the accumulation.
      acc_tp <- lookupType acc_nm
      -- NOTE(review): assumes isInvarToParDim returns an in-range
      -- position for gid_ind.
      let inv_dim_len = segSpaceDims seg_space !! gid_ind
          ((redop0, neutral), el_tps) = getAccLambda acc_tp
      redop <- renameLambda redop0
      let red =
            Reduce
              { redComm = Commutative,
                redLambda = redop,
                redNeutral = neutral
              }
          -- Only the code1 statements the accumulator depends on go into
          -- the map-reduce lambda.
          code1' =
            stmsFromList $
              filter (dependsOnAcc pat_acc_nm variance) $
                stmsToList code1
      (code1'', code1_tr_host) <-
        transposeFVs (freeIn kerstm) variance invar_gid code1'
      let map_lam_body = mkBody code1'' $ map (SubExpRes (Certs [])) acc_vals
          map_lam0 = Lambda [Param mempty invar_gid (Prim int64)] el_tps map_lam_body
      map_lam <- renameLambda map_lam0
      (k1_res, ker1_stms) <- runBuilderT' $ do
        iota <-
          letExp "iota" $
            BasicOp $
              Iota inv_dim_len (intConst Int64 0) (intConst Int64 1) Int64
        let op_exp = Op (OtherOp (Screma inv_dim_len [iota] (ScremaForm map_lam [] [red])))
        res_redmap <- letTupExp "res_mapred" op_exp
        letSubExp (baseString pat_acc_nm ++ "_big_update") $
          BasicOp (UpdateAcc safety acc_nm acc_inds $ map Var res_redmap)
      gid_flat_1 <- newVName "gid_flat"
      let space1 = SegSpace gid_flat_1 gid_dims_new
          level1 = SegThread (SegNoVirtFull (SegSeqDims [])) Nothing
          kbody1 = KernelBody () ker1_stms [Returns ResultMaySimplify (Certs []) k1_res]
      -- The new kernel reuses the parent kernel's 'aux'.
      ker_exp <- renameExp $ Op (SegOp (SegMap level1 space1 [acc_tp] kbody1))
      let ker1 = Let pat_accum aux ker_exp
          -- 2. The residual kernel: the original body without the update.
          ker2_body = old_kbody {kernelBodyStms = code1 <> code2}
      ker2_exp <- renameExp $ Op (SegOp (SegMap seg_thd seg_space kres_tps ker2_body))
      let ker2 = Let pat_ker aux ker2_exp
      pure $ Just (code1_tr_host <> oneStm ker1, ker2)
  where
    -- Does accumulator index @acc_ind@ equal, or depend on, the gid of
    -- @par_dim@? Constant indices never do.
    isIndVarToParDim :: VarianceTable -> SubExp -> (VName, b) -> Bool
    isIndVarToParDim _ (Constant _) _ = False
    isIndVarToParDim variance (Var acc_ind) (par_gid, _) =
      acc_ind == par_gid
        || par_gid `nameIn` M.findWithDefault mempty acc_ind variance

    -- Move the first still-unused dimension that @acc_ind@ varies with to
    -- the front of @inner_dims@. If there is none, nothing changes.
    foldfunReorder ::
      VarianceTable ->
      ([(VName, b)], [(VName, b)]) ->
      SubExp ->
      ([(VName, b)], [(VName, b)])
    foldfunReorder variance (unused_dims, inner_dims) acc_ind =
      case L.break (isIndVarToParDim variance acc_ind) unused_dims of
        (before, dim : after) -> (before ++ after, dim : inner_dims)
        (_, []) -> (unused_dims, inner_dims)

    -- Dimensions not matched by any accumulator index come first. They are
    -- followed by the matched ones, in the order of the indices they match.
    -- The indices are processed right to left, and each match is consed
    -- onto the front.
    reorderParDims :: VarianceTable -> [SubExp] -> [(VName, b)] -> [(VName, b)]
    reorderParDims variance acc_inds gid_dims =
      let (invar_dims, inner_dims) =
            L.foldl' (foldfunReorder variance) (gid_dims, []) (reverse acc_inds)
       in invar_dims ++ inner_dims

    -- Look up the reduction operator and neutral elements recorded in the
    -- environment for an accumulator type, together with its element types.
    getAccLambda :: Type -> ((Lambda GPU, [SubExp]), [Type])
    getAccLambda (Acc tp_id _shp el_tps _) =
      case M.lookup tp_id (fst env) of
        Just lam -> (lam, el_tps)
        Nothing ->
          error $
            "Lookup in environment failed! "
              ++ prettyString tp_id
              ++ " env: "
              ++ show (fst env)
    getAccLambda _ = error "Illegal accumulator type!"

    -- Is a subexpression invariant to the given parallel gid?
    isSeInvar2 :: VarianceTable -> VName -> SubExp -> Bool
    isSeInvar2 variance gid (Var x) =
      gid /= x && gid `notNameIn` M.findWithDefault mempty x variance
    isSeInvar2 _ _ _ = True

    -- Is every component of a dimension index invariant to the gid?
    isDimIdxInvar2 :: VarianceTable -> VName -> DimIndex SubExp -> Bool
    isDimIdxInvar2 variance gid (DimFix d) = isSeInvar2 variance gid d
    isDimIdxInvar2 variance gid (DimSlice d1 d2 d3) =
      all (isSeInvar2 variance gid) [d1, d2, d3]

    -- Is the whole slice invariant to at least one of the given gids?
    isSliceInvar2 :: (Foldable t) => VarianceTable -> Slice SubExp -> t VName -> Bool
    isSliceInvar2 variance slc =
      any (\gid -> all (isDimIdxInvar2 variance gid) (unSlice slc))

    -- An 'Index' statement whose single result the accumulator depends on
    -- must be invariant to some remaining parallel gid or to the
    -- sequentialised gid. Every other statement is accepted here.
    isTileable :: VName -> [(VName, SubExp)] -> VarianceTable -> VName -> Stm GPU -> Bool
    isTileable seq_gid gid_dims variance acc_nm (Let (Pat [pel]) _ (BasicOp (Index _ slc)))
      | patElemName pel `nameIn` M.findWithDefault mempty acc_nm variance =
          isSliceInvar2 variance slc (map fst gid_dims)
            || isSliceInvar2 variance slc [seq_gid]
    isTileable _ _ _ _ _ = True

    -- Does the accumulator depend on any name bound by this statement?
    dependsOnAcc :: VName -> VarianceTable -> Stm GPU -> Bool
    dependsOnAcc pat_acc_nm variance (Let pat _ _) =
      let acc_deps = M.findWithDefault mempty pat_acc_nm variance
       in any (`nameIn` acc_deps) (patNames pat)
genRed2Tile2d _ _ = pure Nothing
-- | Placeholder for turning a generalized reduction into a segmented
-- reduction. It ignores both the environment and the statement and always
-- yields 'Nothing', i.e. it never proposes a rewrite.
genRed2SegRed :: Env -> Stm GPU -> GenRedM (Maybe (Stms GPU, Stm GPU))
genRed2SegRed _env _stm = return Nothing
-- | Rewrite indexing of free arrays inside @stms@ so that the dimension that
-- varies with the thread index @gid@ becomes the innermost one.
--
-- A statement @let pat = arr[i_0, ..., i_{n-1}]@ is rewritten when
--
--   * @arr@ is a member of @fvs@,
--   * every index is a 'DimFix' (no slicing),
--   * exactly one index, at position @ii@, depends on @gid@ (equals it, or
--     has it in its @variance@ entry), and
--   * @ii@ is not already the last dimension.
--
-- In that case a copy @arr_tr = opaque (rearrange perm arr)@ is built, where
-- @perm@ moves dimension @ii@ last, and the statement is changed to index
-- @arr_tr@ with the indices permuted the same way.  Each array is transposed
-- at most once: later accesses that need the same permutation reuse the
-- existing copy, while accesses needing a different permutation are left
-- untouched.
--
-- Returns the (possibly rewritten) statements, in their original order, and
-- the statements that build the transposed copies.  The latter presumably
-- have to be placed before the kernel by the caller (TODO confirm at the
-- call site).
transposeFVs ::
  Names ->
  VarianceTable ->
  VName ->
  Stms GPU ->
  GenRedM (Stms GPU, Stms GPU)
transposeFVs fvs variance gid stms = do
  (tab, stms') <- foldM foldfun (M.empty, mempty) $ stmsToList stms
  let stms_host = M.foldr (\(_, _, s) ss -> ss <> s) mempty tab
  pure (stms', stms_host)
  where
    -- Thread the table of already-transposed arrays
    -- (array |-> (perm, transposed name, defining stms)) through the
    -- statements, appending each (possibly rewritten) statement in order.
    foldfun (tab, all_stms) stm = do
      (tab', stm') <- transposeFV (tab, stm)
      pure (tab', all_stms <> oneStm stm')

    transposeFV (tab, stm@(Let pat aux (BasicOp (Index arr slc))))
      | dims <- unSlice slc,
        all isFixDim dims,
        arr `nameIn` fvs,
        [ii] <- L.findIndices depOnGid dims,
        -- If the gid-dependent index is already innermost, nothing to do.
        ii /= length dims - 1 = do
          let perm = [0 .. ii - 1] ++ [ii + 1 .. length dims - 1] ++ [ii]
              -- Dimension j of (rearrange perm arr) is dimension (perm !! j)
              -- of arr, so permute the indices identically.  'perm' is a
              -- permutation of [0 .. length dims - 1], so (!!) is in range.
              reindex arr_tr =
                Let pat aux . BasicOp . Index arr_tr . Slice $ map (dims !!) perm
          case M.lookup arr tab of
            Nothing -> do
              (arr_tr, stms_tr) <- runBuilderT' $ do
                arr' <-
                  letExp (baseString arr ++ "_trsp") $
                    BasicOp $
                      Rearrange perm arr
                letExp (baseString arr' ++ "_opaque") $
                  BasicOp $
                    Opaque OpaqueNil $
                      Var arr'
              pure (M.insert arr (perm, arr_tr, stms_tr) tab, reindex arr_tr)
            -- Already transposed the same way: reuse the copy instead of
            -- leaving this access on the untransposed array.
            Just (perm', arr_tr, _)
              | perm' == perm -> pure (tab, reindex arr_tr)
            -- Transposed for a different access pattern: keep as is.
            Just _ -> pure (tab, stm)
      where
        isFixDim DimFix {} = True
        isFixDim _ = False
        depOnGid (DimFix (Var nm)) =
          gid == nm || gid `nameIn` M.findWithDefault mempty nm variance
        depOnGid _ = False
    transposeFV r = pure r
-- | Split kernel statements around the first accumulator update.
--
-- Returns the statements preceding the first 'UpdateAcc' statement, that
-- statement itself, and every statement following it (later 'UpdateAcc's
-- included).  If no 'UpdateAcc' occurs, all statements go into the first
-- component, the second is 'Nothing', and the third is empty.
--
-- Implemented with a single 'break' (linear time) rather than repeated list
-- appends in a fold.
matchCodeAccumCode ::
  Stms GPU ->
  (Stms GPU, Maybe (Stm GPU), Stms GPU)
matchCodeAccumCode kstms =
  case break isUpdateAcc (stmsToList kstms) of
    (code1, upd : code2) -> (stmsFromList code1, Just upd, stmsFromList code2)
    (code1, []) -> (stmsFromList code1, Nothing, mempty)
  where
    isUpdateAcc (Let _ _ (BasicOp UpdateAcc {})) = True
    isUpdateAcc _ = False
-- | Find the last dimension of @kspace@ (presumably the innermost) that none
-- of the accumulator indices @acc_inds@ vary over.
--
-- Returns that dimension's thread-index name and its position within the
-- segment space.  Yields 'Nothing' when every dimension is variant, or when
-- any thread index of @kspace@ belongs to @branch_variant@.
isInvarToParDim ::
  Names ->
  SegSpace ->
  VarianceTable ->
  [SubExp] ->
  Maybe (VName, Int)
isInvarToParDim branch_variant kspace variance acc_inds
  | any (`nameIn` branch_variant) gids = Nothing
  | otherwise =
      listToMaybe . reverse $
        filter ((`notNameIn` variant_gids) . fst) (zip gids [0 ..])
  where
    gids = map fst $ unSegSpace kspace
    -- Thread indices that at least one accumulator index depends on.
    variant_gids = namesFromList $ concatMap variantTo acc_inds
    -- A variable index is variant to a thread index if it is that index, or
    -- if the thread index occurs in its variance entry; constants are not.
    variantTo (Var ind) =
      let deps = M.findWithDefault mempty ind variance
       in filter (\g -> g == ind || g `nameIn` deps) gids
    variantTo _ = []
-- | If every kernel result is a 'Returns' tagged 'ResultMaySimplify', return
-- the concatenation of their certificates together with the returned
-- sub-expressions, both in result order; otherwise 'Nothing'.
--
-- An empty result list yields @Just ([], [])@.  Traversing in 'Maybe' keeps
-- the function total (no unreachable 'error' branch) and linear-time.
allGoodReturns :: [KernelResult] -> Maybe ([VName], [SubExp])
allGoodReturns kres = do
  (certs, ses) <- unzip <$> mapM goodReturn kres
  pure (concatMap unCerts certs, ses)
  where
    goodReturn :: KernelResult -> Maybe (Certs, SubExp)
    goodReturn (Returns ResultMaySimplify c se) = Just (c, se)
    goodReturn _ = Nothing
-- | Estimate the cost of the kernel statements that both the accumulator
-- @pat_acc_nm@ and the kernel results @r_ses@ depend on.
--
-- The accumulator's dependencies come from @variance@.  The results'
-- dependencies come from a table rebuilt over @kstms@ by
-- 'varianceInStmsWithout', in which statements binding @pat_acc_nm@ are
-- not recorded.  Only constant results are ignored.  Every statement that
-- binds a name in the intersection of the two sets is charged with
-- 'costRedundantStmt', and the charges are combined with 'addCosts',
-- starting from @Small 0@.
costRedundantExecution ::
  VarianceTable ->
  VName ->
  [SubExp] ->
  Stms GPU ->
  Cost
costRedundantExecution variance pat_acc_nm r_ses kstms =
  L.foldl' charge (Small 0) kstms
  where
    acc_deps = M.findWithDefault mempty pat_acc_nm variance
    table_without_acc = varianceInStmsWithout (oneName pat_acc_nm) mempty kstms
    res_deps =
      mconcat [M.findWithDefault mempty nm table_without_acc | Var nm <- r_ses]
    shared_deps = res_deps `namesIntersection` acc_deps
    charge cost stm
      | namesFromList (patNames (stmPat stm)) `namesIntersect` shared_deps =
          addCosts cost (costRedundantStmt stm)
      | otherwise = cost
-- | Extend the variance table @vartab@ with each statement of @stms@, in
-- order, using 'varianceInStmWithout'.  Statements that bind any name in
-- @nms@ leave the table unchanged.
varianceInStmsWithout :: Names -> VarianceTable -> Stms GPU -> VarianceTable
varianceInStmsWithout nms vartab stms =
  L.foldl' (varianceInStmWithout nms) vartab stms
-- | Record in @vartab@ the variance of every name bound by @stm@, unless
-- @stm@ binds a name from @cuts@, in which case @vartab@ is returned as is.
--
-- Each bound name is mapped to the union, over every free variable @v@ of
-- @stm@, of @v@ itself and @v@'s existing entry in @vartab@ (an existing
-- entry for a bound name is overwritten).
varianceInStmWithout cuts vartab stm
  | namesFromList bound `namesIntersect` cuts = vartab
  | otherwise = L.foldl' (\tab v -> M.insert v deps tab) vartab bound
  where
    bound = patNames $ stmPat stm
    deps = foldMap depsOf $ namesToList $ freeIn stm
    depsOf v = oneName v <> M.findWithDefault mempty v vartab
-- | Coarse cost estimate produced by 'costRedundantStmt' and friends.
--
-- 'Small' carries an additive count.  'Big' and 'Break' are absorbing
-- markers: when combined with 'addCosts', 'Break' dominates 'Big', which
-- dominates any 'Small'.
data Cost
  = Small Int
  | Big
  | Break
  deriving (Eq)
-- | Combine two costs additively.  'Break' absorbs everything, then 'Big'
-- absorbs any 'Small'; two 'Small' costs are summed.
addCosts :: Cost -> Cost -> Cost
addCosts lhs rhs = case (lhs, rhs) of
  (Break, _) -> Break
  (_, Break) -> Break
  (Big, _) -> Big
  (_, Big) -> Big
  (Small a, Small b) -> Small (a + b)
-- | Combine the costs of alternatives: two 'Small' costs give their
-- maximum, and any other pair is combined with 'addCosts', so 'Big' and
-- 'Break' propagate.
maxCost :: Cost -> Cost -> Cost
maxCost lhs rhs = case (lhs, rhs) of
  (Small a, Small b) -> Small (max a b)
  _ -> addCosts lhs rhs
-- | Total cost of a body: 'costRedundantStmt' of each of its statements,
-- combined with 'addCosts' starting from @Small 0@.
costBody :: Body GPU -> Cost
costBody =
  L.foldl' (\acc stm -> addCosts acc (costRedundantStmt stm)) (Small 0) . bodyStms
-- | Heuristic cost of a single statement, determined by its expression:
--
--   * 'Break': 'Update', 'FlatUpdate' and 'UpdateAcc'.
--   * 'Big': 'Op', 'Loop', 'Apply', 'WithAcc', 'Concat', 'Manifest',
--     'Replicate', and array literals whose elements are arrays.
--   * @Small 1@: other array literals, and 'Index' where every index is a
--     'DimFix'.
--   * @Small 0@: 'Index' with any slicing dimension, 'FlatIndex', and every
--     remaining basic operation.
--   * 'Match': the 'maxCost' of the default body and all case bodies, each
--     costed with 'costBody'.
costRedundantStmt :: Stm GPU -> Cost
costRedundantStmt = expCost . stmExp
  where
    expCost :: Exp GPU -> Cost
    expCost e = case e of
      Op _ -> Big
      Loop {} -> Big
      Apply {} -> Big
      WithAcc {} -> Big
      Match _ cases defbody _ ->
        L.foldl' maxCost (costBody defbody) (map (costBody . caseBody) cases)
      BasicOp op -> basicOpCost op

    basicOpCost :: BasicOp -> Cost
    basicOpCost op = case op of
      ArrayLit _ Array {} -> Big
      ArrayLit {} -> Small 1
      Index _ slc
        | all isFixDim (unSlice slc) -> Small 1
        | otherwise -> Small 0
      FlatIndex {} -> Small 0
      Update {} -> Break
      FlatUpdate {} -> Break
      Concat {} -> Big
      Manifest {} -> Big
      Replicate {} -> Big
      UpdateAcc {} -> Break
      _ -> Small 0

    isFixDim :: DimIndex d -> Bool
    isFixDim DimFix {} = True
    isFixDim _ = False