{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.TileLoops (tileLoops) where
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.Map.Strict qualified as M
import Data.Maybe (mapMaybe)
import Data.Sequence qualified as Seq
import Futhark.Analysis.Alias qualified as Alias
import Futhark.IR.GPU
import Futhark.IR.Prop.Aliases (consumedInStm)
import Futhark.MonadFreshNames
import Futhark.Optimise.BlkRegTiling
import Futhark.Optimise.TileLoops.Shared
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename
import Prelude hiding (quot)
-- | The pass definition: tile stream loops found inside GPU kernels.
--
-- Each function's statements are optimised with 'optimiseStms', starting
-- from an environment made of two empty maps.
tileLoops :: Pass GPU GPU
tileLoops =
  Pass "tile loops" "Tile stream loops inside kernels" (intraproceduralTransformation onStms)
  where
    -- Run the tiling optimisation over one group of statements in the
    -- given scope, threading the fresh-name source through the caller's
    -- monad.
    onStms :: (MonadFreshNames m) => Scope GPU -> Stms GPU -> m (Stms GPU)
    onStms scope stms =
      let empty_env = (M.empty, M.empty)
          tiling = runReaderT (optimiseStms empty_env stms) scope
       in modifyNameSource (runState tiling)
-- | Optimise the statements of a body; the body's result is kept as is.
optimiseBody :: Env -> Body GPU -> TileM (Body GPU)
optimiseBody env (Body () body_stms body_res) = do
  body_stms' <- optimiseStms env body_stms
  pure $ Body () body_stms' body_res
-- | Optimise a sequence of statements in order, with the statements'
-- own bindings added to the local scope.
--
-- Each statement is handed the environment produced by the statements
-- before it, and may be replaced by several statements.  The environment
-- left over after the final statement is discarded.
optimiseStms :: Env -> Stms GPU -> TileM (Stms GPU)
optimiseStms env stms =
  localScope (scopeOf stms) $
    snd <$> foldM step (env, mempty) (stmsToList stms)
  where
    -- Optimise one statement and append its replacement(s) to the
    -- statements produced so far.
    step :: (Env, Stms GPU) -> Stm GPU -> TileM (Env, Stms GPU)
    step (cur_env, done) stm = do
      (next_env, replacement) <- optimiseStm cur_env stm
      pure (next_env, done <> replacement)
-- | Optimise a single statement.  Returns the (possibly updated)
-- environment and the statements that replace the original one.
--
-- A thread-level 'SegMap' is tried with three strategies in order:
-- 'doRegTiling3D', then 'mmBlkRegTiling', then the generic loop tiling of
-- 'tileInKernelBody'.  The first strategy that returns a result wins, and
-- any extra statements it produces are placed before the rewritten
-- statement.  The environment is returned unchanged for this case.
--
-- Any other statement is first recorded in the environment via
-- 'changeEnv'.  Then every body nested in its expression is optimised
-- recursively with the updated environment.
optimiseStm :: Env -> Stm GPU -> TileM (Env, Stms GPU)
optimiseStm env stm@(Let pat aux (Op (SegOp (SegMap lvl@SegThread {} space ts kbody)))) = do
  res3dtiling <- localScope (scopeOfSegSpace space) $ doRegTiling3D stm
  stms' <-
    case res3dtiling of
      Just (extra_stms, stmt') -> pure (extra_stms <> oneStm stmt')
      Nothing -> do
        blkRegTiling_res <- mmBlkRegTiling env stm
        case blkRegTiling_res of
          Just (extra_stms, stmt') -> pure (extra_stms <> oneStm stmt')
          Nothing -> localScope (scopeOfSegSpace space) $ do
            (host_stms, (lvl', space', kbody')) <-
              tileInKernelBody mempty initial_variance lvl space ts kbody
            let tiled = Let pat aux $ Op $ SegOp $ SegMap lvl' space' ts kbody'
            pure $ host_stms <> oneStm tiled
  pure (env, stms')
  where
    -- Every name bound by the segment space starts out mapped to an
    -- empty set of names.
    initial_variance :: VarianceTable
    initial_variance = M.map mempty $ scopeOfSegSpace space
optimiseStm env (Let pat aux e) = do
  -- 'changeEnv' associates information with the first name bound by the
  -- statement.  A statement with an empty pattern binds nothing, so the
  -- environment is passed through unchanged instead of crashing on 'head'.
  env' <- case patNames pat of
    v : _ -> changeEnv env v e
    [] -> pure env
  e' <- mapExpM (optimise env') e
  pure (env', oneStm $ Let pat aux e')
  where
    -- Recursively optimise every body directly nested in the expression,
    -- extending the local scope with the scope the mapper supplies.
    optimise env' =
      identityMapper {mapOnBody = \scope -> localScope scope . optimiseBody env'}
-- | Try generic loop tiling on the statements of a kernel body.
--
-- Tiling is only attempted when every kernel result is a plain
-- 'Returns'.  In that case the statements and results are wrapped in a
-- 'Body' and handed to 'tileInBody'.  On success, the result holds the
-- host statements produced by tiling, plus the tiling's level, its
-- segment space and the rebuilt kernel body.  In every other case the
-- original level, space and kernel body are returned with no host
-- statements.
tileInKernelBody ::
  Names ->
  VarianceTable ->
  SegLevel ->
  SegSpace ->
  [Type] ->
  KernelBody GPU ->
  TileM (Stms GPU, (SegLevel, SegSpace, KernelBody GPU))
tileInKernelBody branch_variant initial_variance lvl initial_kspace ts kbody =
  case traverse simpleResult (kernelBodyResult kbody) of
    Nothing -> unchanged
    Just kbody_res -> do
      let plain_body = Body () (kernelBodyStms kbody) kbody_res
      tileInBody branch_variant initial_variance lvl initial_kspace ts plain_body >>= \case
        Nothing -> unchanged
        Just (host_stms, tiling, tiledBody) -> do
          (res', stms') <- runBuilder $ do
            tiled_names <- tiledBody mempty mempty
            mapM (tilingTileReturns tiling) tiled_names
          pure
            ( host_stms,
              (tilingLevel tiling, tilingSpace tiling, KernelBody () stms' res')
            )
  where
    -- The fallback when no tiling applies: leave the kernel as it was.
    unchanged :: TileM (Stms GPU, (SegLevel, SegSpace, KernelBody GPU))
    unchanged = pure (mempty, (lvl, initial_kspace, kbody))

    -- Only plain 'Returns' results can be turned into a 'Body' result.
    simpleResult (Returns _ cs se) = Just (SubExpRes cs se)
    simpleResult _ = Nothing
tileInBody ::
Names ->
VarianceTable ->
SegLevel ->
SegSpace ->
[Type] ->
Body GPU ->
TileM (Maybe (Stms GPU, Tiling, TiledBody))
tileInBody :: Names
-> AliasTable
-> SegLevel
-> SegSpace
-> [Type]
-> Body GPU
-> TileM (Maybe (Stms GPU, Tiling, TiledBody))
tileInBody Names
branch_variant AliasTable
initial_variance SegLevel
initial_lvl SegSpace
initial_space [Type]
res_ts (Body () Stms GPU
initial_kstms Result
stms_res) =
Stms GPU
-> [Stm GPU] -> TileM (Maybe (Stms GPU, Tiling, TiledBody))
descend Stms GPU
forall a. Monoid a => a
mempty ([Stm GPU] -> TileM (Maybe (Stms GPU, Tiling, TiledBody)))
-> [Stm GPU] -> TileM (Maybe (Stms GPU, Tiling, TiledBody))
forall a b. (a -> b) -> a -> b
$ Stms GPU -> [Stm GPU]
forall rep. Stms rep -> [Stm rep]
stmsToList Stms GPU
initial_kstms
where
variance :: AliasTable
variance = AliasTable -> Stms GPU -> AliasTable
varianceInStms AliasTable
initial_variance Stms GPU
initial_kstms
descend :: Stms GPU
-> [Stm GPU] -> TileM (Maybe (Stms GPU, Tiling, TiledBody))
descend Stms GPU
_ [] =
Maybe (Stms GPU, Tiling, TiledBody)
-> TileM (Maybe (Stms GPU, Tiling, TiledBody))
forall a. a -> ReaderT (Scope GPU) (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (Stms GPU, Tiling, TiledBody)
forall a. Maybe a
Nothing
descend Stms GPU
prestms (Stm GPU
stm_to_tile : [Stm GPU]
poststms)
| ([VName]
gtids, [SubExp]
kdims) <- [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
initial_space,
Just (SubExp
w, [VName]
arrs, (Commutativity, Lambda GPU, [SubExp], Lambda GPU)
form) <- Stm GPU
-> Maybe
(SubExp, [VName],
(Commutativity, Lambda GPU, [SubExp], Lambda GPU))
tileable Stm GPU
stm_to_tile,
Just [InputArray]
inputs <-
(VName -> Maybe InputArray) -> [VName] -> Maybe [InputArray]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Names -> AliasTable -> [VName] -> VName -> Maybe InputArray
invariantToOneOfTwoInnerDims Names
branch_variant AliasTable
variance [VName]
gtids) [VName]
arrs,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [(VName, [Int])] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null ([(VName, [Int])] -> Bool) -> [(VName, [Int])] -> Bool
forall a b. (a -> b) -> a -> b
$ [InputArray] -> [(VName, [Int])]
tiledInputs [InputArray]
inputs,
VName
gtid_y : VName
gtid_x : [VName]
top_gtids_rev <- [VName] -> [VName]
forall a. [a] -> [a]
reverse [VName]
gtids,
SubExp
kdim_y : SubExp
kdim_x : [SubExp]
top_kdims_rev <- [SubExp] -> [SubExp]
forall a. [a] -> [a]
reverse [SubExp]
kdims,
Just (Stms GPU
prestms', Stms GPU
poststms') <-
AliasTable
-> Stms GPU -> Stm GPU -> Stms GPU -> Maybe (Stms GPU, Stms GPU)
preludeToPostlude AliasTable
variance Stms GPU
prestms Stm GPU
stm_to_tile ([Stm GPU] -> Stms GPU
forall rep. [Stm rep] -> Stms rep
stmsFromList [Stm GPU]
poststms),
Names
used <- Stm GPU -> Names
forall a. FreeIn a => a -> Names
freeIn Stm GPU
stm_to_tile Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Stms GPU -> Names
forall a. FreeIn a => a -> Names
freeIn Stms GPU
poststms' Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Result -> Names
forall a. FreeIn a => a -> Names
freeIn Result
stms_res =
(Stms GPU, Tiling, TiledBody)
-> Maybe (Stms GPU, Tiling, TiledBody)
forall a. a -> Maybe a
Just ((Stms GPU, Tiling, TiledBody)
-> Maybe (Stms GPU, Tiling, TiledBody))
-> ((Stms GPU, Tiling, TiledBody) -> (Stms GPU, Tiling, TiledBody))
-> (Stms GPU, Tiling, TiledBody)
-> Maybe (Stms GPU, Tiling, TiledBody)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegSpace
-> AliasTable
-> Stms GPU
-> Names
-> (Stms GPU, Tiling, TiledBody)
-> (Stms GPU, Tiling, TiledBody)
injectPrelude SegSpace
initial_space AliasTable
variance Stms GPU
prestms' Names
used
((Stms GPU, Tiling, TiledBody)
-> Maybe (Stms GPU, Tiling, TiledBody))
-> ReaderT
(Scope GPU) (State VNameSource) (Stms GPU, Tiling, TiledBody)
-> TileM (Maybe (Stms GPU, Tiling, TiledBody))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> DoTiling (VName, VName) (SubExp, SubExp)
-> [Type]
-> Pat Type
-> (VName, VName)
-> (SubExp, SubExp)
-> SubExp
-> (Commutativity, Lambda GPU, [SubExp], Lambda GPU)
-> [InputArray]
-> Stms GPU
-> Result
-> ReaderT
(Scope GPU) (State VNameSource) (Stms GPU, Tiling, TiledBody)
forall gtids kdims.
DoTiling gtids kdims
-> [Type]
-> Pat Type
-> gtids
-> kdims
-> SubExp
-> (Commutativity, Lambda GPU, [SubExp], Lambda GPU)
-> [InputArray]
-> Stms GPU
-> Result
-> ReaderT
(Scope GPU) (State VNameSource) (Stms GPU, Tiling, TiledBody)
tileGeneric
([(VName, SubExp)] -> DoTiling (VName, VName) (SubExp, SubExp)
tiling2d ([(VName, SubExp)] -> DoTiling (VName, VName) (SubExp, SubExp))
-> [(VName, SubExp)] -> DoTiling (VName, VName) (SubExp, SubExp)
forall a b. (a -> b) -> a -> b
$ [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
top_gtids_rev [SubExp]
top_kdims_rev)
[Type]
res_ts
(Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm_to_tile)
(VName
gtid_x, VName
gtid_y)
(SubExp
kdim_x, SubExp
kdim_y)
SubExp
w
(Commutativity, Lambda GPU, [SubExp], Lambda GPU)
form
[InputArray]
inputs
Stms GPU
poststms'
Result
stms_res
| (VName
gtid, SubExp
kdim) : [(VName, SubExp)]
top_space_rev <- [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
initial_space,
Just (SubExp
w, [VName]
arrs, (Commutativity, Lambda GPU, [SubExp], Lambda GPU)
form) <- Stm GPU
-> Maybe
(SubExp, [VName],
(Commutativity, Lambda GPU, [SubExp], Lambda GPU))
tileable Stm GPU
stm_to_tile,
[InputArray]
inputs <- (VName -> InputArray) -> [VName] -> [InputArray]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> AliasTable -> VName -> InputArray
is1DTileable VName
gtid AliasTable
variance) [VName]
arrs,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [(VName, [Int])] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null ([(VName, [Int])] -> Bool) -> [(VName, [Int])] -> Bool
forall a b. (a -> b) -> a -> b
$ [InputArray] -> [(VName, [Int])]
tiledInputs [InputArray]
inputs,
VName
gtid VName -> Names -> Bool
`notNameIn` Names
branch_variant,
Just (Stms GPU
prestms', Stms GPU
poststms') <-
AliasTable
-> Stms GPU -> Stm GPU -> Stms GPU -> Maybe (Stms GPU, Stms GPU)
preludeToPostlude AliasTable
variance Stms GPU
prestms Stm GPU
stm_to_tile ([Stm GPU] -> Stms GPU
forall rep. [Stm rep] -> Stms rep
stmsFromList [Stm GPU]
poststms),
Names
used <- Stm GPU -> Names
forall a. FreeIn a => a -> Names
freeIn Stm GPU
stm_to_tile Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Stms GPU -> Names
forall a. FreeIn a => a -> Names
freeIn Stms GPU
poststms' Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Result -> Names
forall a. FreeIn a => a -> Names
freeIn Result
stms_res =
(Stms GPU, Tiling, TiledBody)
-> Maybe (Stms GPU, Tiling, TiledBody)
forall a. a -> Maybe a
Just ((Stms GPU, Tiling, TiledBody)
-> Maybe (Stms GPU, Tiling, TiledBody))
-> ((Stms GPU, Tiling, TiledBody) -> (Stms GPU, Tiling, TiledBody))
-> (Stms GPU, Tiling, TiledBody)
-> Maybe (Stms GPU, Tiling, TiledBody)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegSpace
-> AliasTable
-> Stms GPU
-> Names
-> (Stms GPU, Tiling, TiledBody)
-> (Stms GPU, Tiling, TiledBody)
injectPrelude SegSpace
initial_space AliasTable
variance Stms GPU
prestms' Names
used
((Stms GPU, Tiling, TiledBody)
-> Maybe (Stms GPU, Tiling, TiledBody))
-> ReaderT
(Scope GPU) (State VNameSource) (Stms GPU, Tiling, TiledBody)
-> TileM (Maybe (Stms GPU, Tiling, TiledBody))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> DoTiling VName SubExp
-> [Type]
-> Pat Type
-> VName
-> SubExp
-> SubExp
-> (Commutativity, Lambda GPU, [SubExp], Lambda GPU)
-> [InputArray]
-> Stms GPU
-> Result
-> ReaderT
(Scope GPU) (State VNameSource) (Stms GPU, Tiling, TiledBody)
forall gtids kdims.
DoTiling gtids kdims
-> [Type]
-> Pat Type
-> gtids
-> kdims
-> SubExp
-> (Commutativity, Lambda GPU, [SubExp], Lambda GPU)
-> [InputArray]
-> Stms GPU
-> Result
-> ReaderT
(Scope GPU) (State VNameSource) (Stms GPU, Tiling, TiledBody)
tileGeneric
([(VName, SubExp)] -> DoTiling VName SubExp
tiling1d ([(VName, SubExp)] -> DoTiling VName SubExp)
-> [(VName, SubExp)] -> DoTiling VName SubExp
forall a b. (a -> b) -> a -> b
$ [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse [(VName, SubExp)]
top_space_rev)
[Type]
res_ts
(Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm_to_tile)
VName
gtid
SubExp
kdim
SubExp
w
(Commutativity, Lambda GPU, [SubExp], Lambda GPU)
form
[InputArray]
inputs
Stms GPU
poststms'
Result
stms_res
| Loop [(FParam GPU, SubExp)]
merge (ForLoop VName
i IntType
it SubExp
bound) Body GPU
loopbody <- Stm GPU -> Exp GPU
forall rep. Stm rep -> Exp rep
stmExp Stm GPU
stm_to_tile,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ ((Param (TypeBase Shape Uniqueness), SubExp) -> Bool)
-> [(Param (TypeBase Shape Uniqueness), SubExp)] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any ((VName -> Names -> Bool
`nameIn` [(Param (TypeBase Shape Uniqueness), SubExp)] -> Names
forall a. FreeIn a => a -> Names
freeIn [(Param (TypeBase Shape Uniqueness), SubExp)]
[(FParam GPU, SubExp)]
merge) (VName -> Bool)
-> ((Param (TypeBase Shape Uniqueness), SubExp) -> VName)
-> (Param (TypeBase Shape Uniqueness), SubExp)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName (Param (TypeBase Shape Uniqueness) -> VName)
-> ((Param (TypeBase Shape Uniqueness), SubExp)
-> Param (TypeBase Shape Uniqueness))
-> (Param (TypeBase Shape Uniqueness), SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param (TypeBase Shape Uniqueness), SubExp)
-> Param (TypeBase Shape Uniqueness)
forall a b. (a, b) -> a
fst) [(Param (TypeBase Shape Uniqueness), SubExp)]
[(FParam GPU, SubExp)]
merge,
Just (Stms GPU
prestms', Stms GPU
poststms') <-
AliasTable
-> Stms GPU -> Stm GPU -> Stms GPU -> Maybe (Stms GPU, Stms GPU)
preludeToPostlude AliasTable
variance Stms GPU
prestms Stm GPU
stm_to_tile ([Stm GPU] -> Stms GPU
forall rep. [Stm rep] -> Stms rep
stmsFromList [Stm GPU]
poststms) = do
let branch_variant' :: Names
branch_variant' =
Names
branch_variant
Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [Names] -> Names
forall a. Monoid a => [a] -> a
mconcat
( (VName -> Names) -> [VName] -> [Names]
forall a b. (a -> b) -> [a] -> [b]
map
((VName -> AliasTable -> Names) -> AliasTable -> VName -> Names
forall a b c. (a -> b -> c) -> b -> a -> c
flip (Names -> VName -> AliasTable -> Names
forall k a. Ord k => a -> k -> Map k a -> a
M.findWithDefault Names
forall a. Monoid a => a
mempty) AliasTable
variance)
(Names -> [VName]
namesToList (SubExp -> Names
forall a. FreeIn a => a -> Names
freeIn SubExp
bound))
)
merge_params :: [Param (TypeBase Shape Uniqueness)]
merge_params = ((Param (TypeBase Shape Uniqueness), SubExp)
-> Param (TypeBase Shape Uniqueness))
-> [(Param (TypeBase Shape Uniqueness), SubExp)]
-> [Param (TypeBase Shape Uniqueness)]
forall a b. (a -> b) -> [a] -> [b]
map (Param (TypeBase Shape Uniqueness), SubExp)
-> Param (TypeBase Shape Uniqueness)
forall a b. (a, b) -> a
fst [(Param (TypeBase Shape Uniqueness), SubExp)]
[(FParam GPU, SubExp)]
merge
Maybe (Stms GPU, Tiling, TiledBody)
maybe_tiled <-
Scope GPU
-> TileM (Maybe (Stms GPU, Tiling, TiledBody))
-> TileM (Maybe (Stms GPU, Tiling, TiledBody))
forall a.
Scope GPU
-> ReaderT (Scope GPU) (State VNameSource) a
-> ReaderT (Scope GPU) (State VNameSource) a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (VName -> NameInfo GPU -> Scope GPU -> Scope GPU
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
i (IntType -> NameInfo GPU
forall rep. IntType -> NameInfo rep
IndexName IntType
it) (Scope GPU -> Scope GPU) -> Scope GPU -> Scope GPU
forall a b. (a -> b) -> a -> b
$ [Param (TypeBase Shape Uniqueness)] -> Scope GPU
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param (TypeBase Shape Uniqueness)]
merge_params)
(TileM (Maybe (Stms GPU, Tiling, TiledBody))
-> TileM (Maybe (Stms GPU, Tiling, TiledBody)))
-> TileM (Maybe (Stms GPU, Tiling, TiledBody))
-> TileM (Maybe (Stms GPU, Tiling, TiledBody))
forall a b. (a -> b) -> a -> b
$ Names
-> AliasTable
-> SegLevel
-> SegSpace
-> [Type]
-> Body GPU
-> TileM (Maybe (Stms GPU, Tiling, TiledBody))
tileInBody
Names
branch_variant'
AliasTable
variance
SegLevel
initial_lvl
SegSpace
initial_space
((Param (TypeBase Shape Uniqueness) -> Type)
-> [Param (TypeBase Shape Uniqueness)] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape Uniqueness) -> Type
forall dec. Typed dec => Param dec -> Type
paramType [Param (TypeBase Shape Uniqueness)]
merge_params)
(Body GPU -> TileM (Maybe (Stms GPU, Tiling, TiledBody)))
-> Body GPU -> TileM (Maybe (Stms GPU, Tiling, TiledBody))
forall a b. (a -> b) -> a -> b
$ Stms GPU -> Result -> Body GPU
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (Body GPU -> Stms GPU
forall rep. Body rep -> Stms rep
bodyStms Body GPU
loopbody) (Body GPU -> Result
forall rep. Body rep -> Result
bodyResult Body GPU
loopbody)
case Maybe (Stms GPU, Tiling, TiledBody)
maybe_tiled of
Maybe (Stms GPU, Tiling, TiledBody)
Nothing -> TileM (Maybe (Stms GPU, Tiling, TiledBody))
next
Just (Stms GPU, Tiling, TiledBody)
tiled ->
(Stms GPU, Tiling, TiledBody)
-> Maybe (Stms GPU, Tiling, TiledBody)
forall a. a -> Maybe a
Just
((Stms GPU, Tiling, TiledBody)
-> Maybe (Stms GPU, Tiling, TiledBody))
-> ReaderT
(Scope GPU) (State VNameSource) (Stms GPU, Tiling, TiledBody)
-> TileM (Maybe (Stms GPU, Tiling, TiledBody))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SegSpace
-> AliasTable
-> Stms GPU
-> Names
-> (Stms GPU, Tiling, TiledBody)
-> [Type]
-> Pat Type
-> StmAux (ExpDec GPU)
-> [(FParam GPU, SubExp)]
-> VName
-> IntType
-> SubExp
-> Stms GPU
-> Result
-> ReaderT
(Scope GPU) (State VNameSource) (Stms GPU, Tiling, TiledBody)
tileLoop
SegSpace
initial_space
AliasTable
variance
Stms GPU
prestms'
(Body GPU -> Names
forall a. FreeIn a => a -> Names
freeIn Body GPU
loopbody Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [(Param (TypeBase Shape Uniqueness), SubExp)] -> Names
forall a. FreeIn a => a -> Names
freeIn [(Param (TypeBase Shape Uniqueness), SubExp)]
[(FParam GPU, SubExp)]
merge)
(Stms GPU, Tiling, TiledBody)
tiled
[Type]
res_ts
(Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm_to_tile)
(Stm GPU -> StmAux (ExpDec GPU)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm_to_tile)
[(FParam GPU, SubExp)]
merge
VName
i
IntType
it
SubExp
bound
Stms GPU
poststms'
Result
stms_res
| Bool
otherwise = TileM (Maybe (Stms GPU, Tiling, TiledBody))
next
where
next :: TileM (Maybe (Stms GPU, Tiling, TiledBody))
next =
Scope GPU
-> TileM (Maybe (Stms GPU, Tiling, TiledBody))
-> TileM (Maybe (Stms GPU, Tiling, TiledBody))
forall a.
Scope GPU
-> ReaderT (Scope GPU) (State VNameSource) a
-> ReaderT (Scope GPU) (State VNameSource) a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (Stm GPU -> Scope GPU
forall rep a. Scoped rep a => a -> Scope rep
scopeOf Stm GPU
stm_to_tile) (TileM (Maybe (Stms GPU, Tiling, TiledBody))
-> TileM (Maybe (Stms GPU, Tiling, TiledBody)))
-> TileM (Maybe (Stms GPU, Tiling, TiledBody))
-> TileM (Maybe (Stms GPU, Tiling, TiledBody))
forall a b. (a -> b) -> a -> b
$
Stms GPU
-> [Stm GPU] -> TileM (Maybe (Stms GPU, Tiling, TiledBody))
descend (Stms GPU
prestms Stms GPU -> Stms GPU -> Stms GPU
forall a. Semigroup a => a -> a -> a
<> Stm GPU -> Stms GPU
forall rep. Stm rep -> Stms rep
oneStm Stm GPU
stm_to_tile) [Stm GPU]
poststms
-- | Move prelude statements that the tiled statement does not depend on
-- into the postlude.
--
-- A prelude statement stays in the prelude when one of the names it binds
-- is free in @stm_to_tile@, or is a name that a free variable of
-- @stm_to_tile@ is variant to (according to @variance@).  Every other
-- prelude statement is prepended to @postlude@.
--
-- Returns 'Nothing' when the result types of the retained prelude
-- statements mention names bound by those same retained statements.
-- Otherwise returns @Just (retained prelude, new postlude)@.
preludeToPostlude ::
  VarianceTable ->
  Stms GPU ->
  Stm GPU ->
  Stms GPU ->
  Maybe (Stms GPU, Stms GPU)
preludeToPostlude variance prelude stm_to_tile postlude
  | kept_sizes `namesIntersect` kept_bound = Nothing
  | otherwise = Just (kept, moved <> postlude)
  where
    -- Names the tiled statement refers to directly, plus everything those
    -- names are variant to.
    tile_deps = freeIn stm_to_tile
    relevant =
      tile_deps
        <> mconcat
          [M.findWithDefault mempty v variance | v <- namesToList tile_deps]

    isRelevant stm = any (`nameIn` relevant) (patNames (stmPat stm))

    (kept, moved) = Seq.partition isRelevant prelude

    -- Size names appearing in the types of the retained statements, and the
    -- names those statements bind.
    kept_pats = fmap stmPat kept
    kept_sizes = freeIn (foldMap patTypes kept_pats)
    kept_bound = namesFromList (foldMap patNames kept_pats)
-- | Split the statements preceding a tiled construct into three groups.
--
-- 1. Group-invariant statements.  Such a statement is invariant to every
--    name in @private@.  None of the names it binds is bound by a prelude
--    statement whose result is consumed within @prestms@.  It is also
--    invariant to those consumed-result names.
--
-- 2. All remaining (thread-variant) statements.
--
-- 3. The subset of (2) that must be recomputed where needed rather than
--    read back.  These are statements that bind a name in @used_after@ and
--    whose expression is an 'Index' with a non-empty slice, an 'Iota', a
--    'Rearrange' or a 'Reshape'.
partitionPrelude ::
  VarianceTable ->
  Stms GPU ->
  Names ->
  Names ->
  (Stms GPU, Stms GPU, Stms GPU)
partitionPrelude variance prestms private used_after =
  (invariant_prestms, variant_prestms, recomputed_variant_prestms)
  where
    boundBy :: Stm GPU -> [VName]
    boundBy = patNames . stmPat

    namesBoundBy :: Stms GPU -> Names
    namesBoundBy = namesFromList . foldMap boundBy

    -- Invariance is judged by the variance entry of the first bound name.
    -- A statement that binds nothing is considered invariant.
    invariantTo :: Names -> Stm GPU -> Bool
    invariantTo names stm
      | v : _ <- boundBy stm =
          not . any (`nameIn` names) . namesToList $
            M.findWithDefault mempty v variance
      | otherwise = True

    -- Everything consumed somewhere in the prelude, per alias analysis.
    consumed_in_prestms :: Names
    consumed_in_prestms =
      foldMap consumedInStm . fst $ Alias.analyseStms mempty prestms

    -- Names bound by prelude statements whose results are later consumed.
    later_consumed :: Names
    later_consumed =
      namesBoundBy $
        Seq.filter (any (`nameIn` consumed_in_prestms) . boundBy) prestms

    groupInvariant :: Stm GPU -> Bool
    groupInvariant stm =
      and
        [ invariantTo private stm,
          not $ any (`nameIn` later_consumed) (boundBy stm),
          invariantTo later_consumed stm
        ]

    (invariant_prestms, variant_prestms) =
      Seq.partition groupInvariant prestms

    -- Expressions whose results we do not want to precompute.
    isViewLike :: Exp GPU -> Bool
    isViewLike = \case
      BasicOp (Index _ slice) -> not . null $ sliceDims slice
      BasicOp Iota {} -> True
      BasicOp Rearrange {} -> True
      BasicOp Reshape {} -> True
      _ -> False

    must_be_inlined :: Names
    must_be_inlined =
      namesBoundBy $ Seq.filter mustBeInlined variant_prestms
      where
        mustBeInlined stm =
          isViewLike (stmExp stm) && any (`nameIn` used_after) (boundBy stm)

    recomputed_variant_prestms :: Stms GPU
    recomputed_variant_prestms =
      Seq.filter (any (`nameIn` must_be_inlined) . boundBy) variant_prestms
-- | Wrap a tiled body so that it first executes the prelude @prestms@.
--
-- Host statements and the tiling are passed through unchanged.  When the
-- new body runs, the gtids of @initial_space@ dimensions that are not part
-- of the tiling space are added to the private names.  The prelude is then
-- split with 'partitionPrelude', and the three groups are handled as
-- follows:
--
-- * Group-invariant statements are added directly.
-- * Variant statements are computed with 'doPrelude', restricted to the
--   live set needed by @used@ and by the recomputed statements.
-- * Statements that must be recomputed are passed to the original body as
--   extra private statements, together with reads of the arrays
--   'doPrelude' returned.
injectPrelude ::
  SegSpace ->
  VarianceTable ->
  Stms GPU ->
  Names ->
  (Stms GPU, Tiling, TiledBody) ->
  (Stms GPU, Tiling, TiledBody)
injectPrelude initial_space variance prestms used (host_stms, tiling, tiledBody) =
  (host_stms, tiling, withPrelude)
  where
    tiled_dims = unSegSpace (tilingSpace tiling)

    -- Thread indices of dimensions that the tiling did not take over.
    untiled_gtids =
      [gtid | dim@(gtid, _) <- unSegSpace initial_space, dim `notElem` tiled_dims]

    withPrelude :: TiledBody
    withPrelude private privstms = do
      let private' = private <> namesFromList untiled_gtids
          (invariant, precomputed, recomputed) =
            partitionPrelude variance prestms private' used
          live_set =
            namesToList . liveSet precomputed $ used <> freeIn recomputed

      addStms invariant

      prelude_arrs <-
        inScopeOf precomputed $
          doPrelude tiling privstms precomputed live_set

      let read_prelude = mkReadPreludeValues prelude_arrs live_set
      tiledBody private' $ PrivStms recomputed read_prelude <> privstms
tileLoop ::
SegSpace ->
VarianceTable ->
Stms GPU ->
Names ->
(Stms GPU, Tiling, TiledBody) ->
[Type] ->
Pat Type ->
StmAux (ExpDec GPU) ->
[(FParam GPU, SubExp)] ->
VName ->
IntType ->
SubExp ->
Stms GPU ->
Result ->
TileM (Stms GPU, Tiling, TiledBody)
tileLoop :: SegSpace
-> AliasTable
-> Stms GPU
-> Names
-> (Stms GPU, Tiling, TiledBody)
-> [Type]
-> Pat Type
-> StmAux (ExpDec GPU)
-> [(FParam GPU, SubExp)]
-> VName
-> IntType
-> SubExp
-> Stms GPU
-> Result
-> ReaderT
(Scope GPU) (State VNameSource) (Stms GPU, Tiling, TiledBody)
tileLoop SegSpace
initial_space AliasTable
variance Stms GPU
prestms Names
used_in_body (Stms GPU
host_stms, Tiling
tiling, TiledBody
tiledBody) [Type]
res_ts Pat Type
pat StmAux (ExpDec GPU)
aux [(FParam GPU, SubExp)]
merge VName
i IntType
it SubExp
bound Stms GPU
poststms Result
poststms_res = do
let prestms_used :: Names
prestms_used = Names
used_in_body Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Stms GPU -> Names
forall a. FreeIn a => a -> Names
freeIn Stms GPU
poststms Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Result -> Names
forall a. FreeIn a => a -> Names
freeIn Result
poststms_res
( Stms GPU
invariant_prestms,
Stms GPU
precomputed_variant_prestms,
Stms GPU
recomputed_variant_prestms
) =
AliasTable
-> Stms GPU -> Names -> Names -> (Stms GPU, Stms GPU, Stms GPU)
partitionPrelude AliasTable
variance Stms GPU
prestms Names
tiled_kdims Names
prestms_used
let ([Param (TypeBase Shape Uniqueness)]
mergeparams, [SubExp]
mergeinits) = [(Param (TypeBase Shape Uniqueness), SubExp)]
-> ([Param (TypeBase Shape Uniqueness)], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param (TypeBase Shape Uniqueness), SubExp)]
[(FParam GPU, SubExp)]
merge
tileDim :: TypeBase Shape Uniqueness -> TypeBase Shape Uniqueness
tileDim TypeBase Shape Uniqueness
t = TypeBase Shape Uniqueness
-> Shape -> Uniqueness -> TypeBase Shape Uniqueness
forall shape u_unused u.
ArrayShape shape =>
TypeBase shape u_unused -> shape -> u -> TypeBase shape u
arrayOf TypeBase Shape Uniqueness
t (Tiling -> Shape
tilingTileShape Tiling
tiling) (Uniqueness -> TypeBase Shape Uniqueness)
-> Uniqueness -> TypeBase Shape Uniqueness
forall a b. (a -> b) -> a -> b
$ TypeBase Shape Uniqueness -> Uniqueness
forall shape. TypeBase shape Uniqueness -> Uniqueness
uniqueness TypeBase Shape Uniqueness
t
merge_scope :: Scope GPU
merge_scope = VName -> NameInfo GPU -> Scope GPU -> Scope GPU
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
i (IntType -> NameInfo GPU
forall rep. IntType -> NameInfo rep
IndexName IntType
it) (Scope GPU -> Scope GPU) -> Scope GPU -> Scope GPU
forall a b. (a -> b) -> a -> b
$ [Param (TypeBase Shape Uniqueness)] -> Scope GPU
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param (TypeBase Shape Uniqueness)]
mergeparams
tiledBody' :: TiledBody
tiledBody' Names
private PrivStms
privstms = Scope GPU
-> BuilderT GPU (State VNameSource) [VName]
-> BuilderT GPU (State VNameSource) [VName]
forall a.
Scope GPU
-> BuilderT GPU (State VNameSource) a
-> BuilderT GPU (State VNameSource) a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (Stms GPU -> Scope GPU
forall rep a. Scoped rep a => a -> Scope rep
scopeOf Stms GPU
host_stms Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> Scope GPU
merge_scope) (BuilderT GPU (State VNameSource) [VName]
-> BuilderT GPU (State VNameSource) [VName])
-> BuilderT GPU (State VNameSource) [VName]
-> BuilderT GPU (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ do
Stms (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms Stms (Rep (BuilderT GPU (State VNameSource)))
Stms GPU
invariant_prestms
let live_set :: [VName]
live_set =
Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$
Stms GPU -> Names -> Names
forall a. FreeIn a => Stms GPU -> a -> Names
liveSet Stms GPU
precomputed_variant_prestms (Names -> Names) -> Names -> Names
forall a b. (a -> b) -> a -> b
$
Stms GPU -> Names
forall a. FreeIn a => a -> Names
freeIn Stms GPU
recomputed_variant_prestms Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
prestms_used
[VName]
prelude_arrs <-
Stms GPU
-> BuilderT GPU (State VNameSource) [VName]
-> BuilderT GPU (State VNameSource) [VName]
forall rep a (m :: * -> *) b.
(Scoped rep a, LocalScope rep m) =>
a -> m b -> m b
inScopeOf Stms GPU
precomputed_variant_prestms (BuilderT GPU (State VNameSource) [VName]
-> BuilderT GPU (State VNameSource) [VName])
-> BuilderT GPU (State VNameSource) [VName]
-> BuilderT GPU (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
Tiling
-> PrivStms
-> Stms GPU
-> [VName]
-> BuilderT GPU (State VNameSource) [VName]
doPrelude Tiling
tiling PrivStms
privstms Stms GPU
precomputed_variant_prestms [VName]
live_set
[Param (TypeBase Shape Uniqueness)]
mergeparams' <- [Param (TypeBase Shape Uniqueness)]
-> (Param (TypeBase Shape Uniqueness)
-> BuilderT
GPU (State VNameSource) (Param (TypeBase Shape Uniqueness)))
-> BuilderT
GPU (State VNameSource) [Param (TypeBase Shape Uniqueness)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Param (TypeBase Shape Uniqueness)]
mergeparams ((Param (TypeBase Shape Uniqueness)
-> BuilderT
GPU (State VNameSource) (Param (TypeBase Shape Uniqueness)))
-> BuilderT
GPU (State VNameSource) [Param (TypeBase Shape Uniqueness)])
-> (Param (TypeBase Shape Uniqueness)
-> BuilderT
GPU (State VNameSource) (Param (TypeBase Shape Uniqueness)))
-> BuilderT
GPU (State VNameSource) [Param (TypeBase Shape Uniqueness)]
forall a b. (a -> b) -> a -> b
$ \(Param Attrs
attrs VName
pname TypeBase Shape Uniqueness
pt) ->
Attrs
-> VName
-> TypeBase Shape Uniqueness
-> Param (TypeBase Shape Uniqueness)
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
attrs (VName
-> TypeBase Shape Uniqueness -> Param (TypeBase Shape Uniqueness))
-> BuilderT GPU (State VNameSource) VName
-> BuilderT
GPU
(State VNameSource)
(TypeBase Shape Uniqueness -> Param (TypeBase Shape Uniqueness))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Char] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName (VName -> [Char]
baseString VName
pname [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
"_group") BuilderT
GPU
(State VNameSource)
(TypeBase Shape Uniqueness -> Param (TypeBase Shape Uniqueness))
-> BuilderT GPU (State VNameSource) (TypeBase Shape Uniqueness)
-> BuilderT
GPU (State VNameSource) (Param (TypeBase Shape Uniqueness))
forall a b.
BuilderT GPU (State VNameSource) (a -> b)
-> BuilderT GPU (State VNameSource) a
-> BuilderT GPU (State VNameSource) b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> TypeBase Shape Uniqueness
-> BuilderT GPU (State VNameSource) (TypeBase Shape Uniqueness)
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (TypeBase Shape Uniqueness -> TypeBase Shape Uniqueness
tileDim TypeBase Shape Uniqueness
pt)
let merge_ts :: [Type]
merge_ts = (Param (TypeBase Shape Uniqueness) -> Type)
-> [Param (TypeBase Shape Uniqueness)] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape Uniqueness) -> Type
forall dec. Typed dec => Param dec -> Type
paramType [Param (TypeBase Shape Uniqueness)]
mergeparams
let inloop_privstms :: PrivStms
inloop_privstms =
Stms GPU -> ReadPrelude -> PrivStms
PrivStms Stms GPU
recomputed_variant_prestms (ReadPrelude -> PrivStms) -> ReadPrelude -> PrivStms
forall a b. (a -> b) -> a -> b
$
[VName] -> [VName] -> ReadPrelude
mkReadPreludeValues [VName]
prelude_arrs [VName]
live_set
[SubExp]
mergeinit' <-
([VName] -> [SubExp])
-> BuilderT GPU (State VNameSource) [VName]
-> BuilderT GPU (State VNameSource) [SubExp]
forall a b.
(a -> b)
-> BuilderT GPU (State VNameSource) a
-> BuilderT GPU (State VNameSource) b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var) (BuilderT GPU (State VNameSource) [VName]
-> BuilderT GPU (State VNameSource) [SubExp])
-> BuilderT GPU (State VNameSource) [VName]
-> BuilderT GPU (State VNameSource) [SubExp]
forall a b. (a -> b) -> a -> b
$
Certs
-> BuilderT GPU (State VNameSource) [VName]
-> BuilderT GPU (State VNameSource) [VName]
forall a.
Certs
-> BuilderT GPU (State VNameSource) a
-> BuilderT GPU (State VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec GPU)
aux) (BuilderT GPU (State VNameSource) [VName]
-> BuilderT GPU (State VNameSource) [VName])
-> BuilderT GPU (State VNameSource) [VName]
-> BuilderT GPU (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
Tiling
-> [Char]
-> ResultManifest
-> (PrimExp VName -> [DimIndex SubExp] -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName]
tilingSegMap Tiling
tiling [Char]
"tiled_loopinit" ResultManifest
ResultPrivate ((PrimExp VName -> [DimIndex SubExp] -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName])
-> (PrimExp VName -> [DimIndex SubExp] -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
\PrimExp VName
in_bounds [DimIndex SubExp]
slice ->
([VName] -> Result)
-> BuilderT GPU (State VNameSource) [VName] -> Builder GPU Result
forall a b.
(a -> b)
-> BuilderT GPU (State VNameSource) a
-> BuilderT GPU (State VNameSource) b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> Result
varsRes (BuilderT GPU (State VNameSource) [VName] -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName] -> Builder GPU Result
forall a b. (a -> b) -> a -> b
$
[Char]
-> PrimExp VName
-> [Type]
-> Builder GPU Result
-> BuilderT GPU (State VNameSource) [VName]
protectOutOfBounds [Char]
"loopinit" PrimExp VName
in_bounds [Type]
merge_ts (Builder GPU Result -> BuilderT GPU (State VNameSource) [VName])
-> Builder GPU Result -> BuilderT GPU (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ do
[DimIndex SubExp]
-> PrivStms -> BuilderT GPU (State VNameSource) ()
addPrivStms [DimIndex SubExp]
slice PrivStms
inloop_privstms
[DimIndex SubExp]
-> PrivStms -> BuilderT GPU (State VNameSource) ()
addPrivStms [DimIndex SubExp]
slice PrivStms
privstms
Result -> Builder GPU Result
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> Builder GPU Result) -> Result -> Builder GPU Result
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Result
subExpsRes [SubExp]
mergeinits
let merge' :: [(Param (TypeBase Shape Uniqueness), SubExp)]
merge' = [Param (TypeBase Shape Uniqueness)]
-> [SubExp] -> [(Param (TypeBase Shape Uniqueness), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase Shape Uniqueness)]
mergeparams' [SubExp]
mergeinit'
let indexLoopParams :: ReadPrelude
indexLoopParams [DimIndex SubExp]
slice =
Scope GPU
-> BuilderT GPU (State VNameSource) ()
-> BuilderT GPU (State VNameSource) ()
forall a.
Scope GPU
-> BuilderT GPU (State VNameSource) a
-> BuilderT GPU (State VNameSource) a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([Param (TypeBase Shape Uniqueness)] -> Scope GPU
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param (TypeBase Shape Uniqueness)]
mergeparams') (BuilderT GPU (State VNameSource) ()
-> BuilderT GPU (State VNameSource) ())
-> BuilderT GPU (State VNameSource) ()
-> BuilderT GPU (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
[(Param (TypeBase Shape Uniqueness),
Param (TypeBase Shape Uniqueness))]
-> ((Param (TypeBase Shape Uniqueness),
Param (TypeBase Shape Uniqueness))
-> BuilderT GPU (State VNameSource) ())
-> BuilderT GPU (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (TypeBase Shape Uniqueness)]
-> [Param (TypeBase Shape Uniqueness)]
-> [(Param (TypeBase Shape Uniqueness),
Param (TypeBase Shape Uniqueness))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase Shape Uniqueness)]
mergeparams [Param (TypeBase Shape Uniqueness)]
mergeparams') (((Param (TypeBase Shape Uniqueness),
Param (TypeBase Shape Uniqueness))
-> BuilderT GPU (State VNameSource) ())
-> BuilderT GPU (State VNameSource) ())
-> ((Param (TypeBase Shape Uniqueness),
Param (TypeBase Shape Uniqueness))
-> BuilderT GPU (State VNameSource) ())
-> BuilderT GPU (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(Param (TypeBase Shape Uniqueness)
to, Param (TypeBase Shape Uniqueness)
from) ->
[VName]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape Uniqueness)
to] (Exp GPU -> BuilderT GPU (State VNameSource) ())
-> (Slice SubExp -> Exp GPU)
-> Slice SubExp
-> BuilderT GPU (State VNameSource) ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp GPU
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp GPU)
-> (Slice SubExp -> BasicOp) -> Slice SubExp -> Exp GPU
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Slice SubExp -> BasicOp
Index (Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape Uniqueness)
from) (Slice SubExp -> BuilderT GPU (State VNameSource) ())
-> Slice SubExp -> BuilderT GPU (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice (Param (TypeBase Shape Uniqueness) -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param (TypeBase Shape Uniqueness)
from) [DimIndex SubExp]
slice
private' :: Names
private' =
Names
private Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [VName] -> Names
namesFromList ((Param (TypeBase Shape Uniqueness) -> VName)
-> [Param (TypeBase Shape Uniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName [Param (TypeBase Shape Uniqueness)]
mergeparams [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ (Param (TypeBase Shape Uniqueness) -> VName)
-> [Param (TypeBase Shape Uniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName [Param (TypeBase Shape Uniqueness)]
mergeparams')
privstms' :: PrivStms
privstms' =
Stms GPU -> ReadPrelude -> PrivStms
PrivStms Stms GPU
forall a. Monoid a => a
mempty ReadPrelude
indexLoopParams PrivStms -> PrivStms -> PrivStms
forall a. Semigroup a => a -> a -> a
<> PrivStms
privstms PrivStms -> PrivStms -> PrivStms
forall a. Semigroup a => a -> a -> a
<> PrivStms
inloop_privstms
Body GPU
loopbody' <-
Scope GPU
-> BuilderT GPU (State VNameSource) (Body GPU)
-> BuilderT GPU (State VNameSource) (Body GPU)
forall a.
Scope GPU
-> BuilderT GPU (State VNameSource) a
-> BuilderT GPU (State VNameSource) a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([Param (TypeBase Shape Uniqueness)] -> Scope GPU
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param (TypeBase Shape Uniqueness)]
mergeparams') (BuilderT GPU (State VNameSource) (Body GPU)
-> BuilderT GPU (State VNameSource) (Body GPU))
-> (Builder GPU Result
-> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU Result
-> BuilderT GPU (State VNameSource) (Body GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Builder GPU Result -> BuilderT GPU (State VNameSource) (Body GPU)
forall rep (m :: * -> *) somerep.
(Buildable rep, MonadFreshNames m, HasScope somerep m,
SameScope somerep rep) =>
Builder rep Result -> m (Body rep)
runBodyBuilder (Builder GPU Result -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU Result
-> BuilderT GPU (State VNameSource) (Body GPU)
forall a b. (a -> b) -> a -> b
$
[VName] -> Result
varsRes ([VName] -> Result)
-> BuilderT GPU (State VNameSource) [VName] -> Builder GPU Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TiledBody
tiledBody Names
private' PrivStms
privstms'
[VName]
accs' <-
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) [VName]
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [VName]
letTupExp [Char]
"tiled_inside_loop" (Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) [VName])
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
[(FParam GPU, SubExp)] -> LoopForm -> Body GPU -> Exp GPU
forall rep.
[(FParam rep, SubExp)] -> LoopForm -> Body rep -> Exp rep
Loop [(Param (TypeBase Shape Uniqueness), SubExp)]
[(FParam GPU, SubExp)]
merge' (VName -> IntType -> SubExp -> LoopForm
ForLoop VName
i IntType
it SubExp
bound) Body GPU
loopbody'
Tiling
-> PrivStms
-> Pat Type
-> [VName]
-> Stms GPU
-> Result
-> [Type]
-> BuilderT GPU (State VNameSource) [VName]
postludeGeneric Tiling
tiling (PrivStms
privstms PrivStms -> PrivStms -> PrivStms
forall a. Semigroup a => a -> a -> a
<> PrivStms
inloop_privstms) Pat Type
pat [VName]
accs' Stms GPU
poststms Result
poststms_res [Type]
res_ts
(Stms GPU, Tiling, TiledBody)
-> ReaderT
(Scope GPU) (State VNameSource) (Stms GPU, Tiling, TiledBody)
forall a. a -> ReaderT (Scope GPU) (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
host_stms, Tiling
tiling, TiledBody
tiledBody')
where
tiled_kdims :: Names
tiled_kdims =
[VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$
((VName, SubExp) -> VName) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst ([(VName, SubExp)] -> [VName]) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$
((VName, SubExp) -> Bool) -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName, SubExp) -> [(VName, SubExp)] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` SegSpace -> [(VName, SubExp)]
unSegSpace (Tiling -> SegSpace
tilingSpace Tiling
tiling)) ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$
SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
initial_space
-- | Run the prelude statements @prestms@ once per thread, inside the
-- per-thread construct built by the tiling's 'tilingSegMap' (description
-- @"prelude"@, 'ResultPrivate'), and return the variables listed in
-- @prestms_live@ as the per-thread results.
--
-- The private statements are added (via 'addPrivStms') before the prelude.
-- All of this is wrapped in 'protectOutOfBounds', so threads whose in-bounds
-- predicate is false take the blank branch instead.
doPrelude :: Tiling -> PrivStms -> Stms GPU -> [VName] -> Builder GPU [VName]
doPrelude tiling privstms prestms prestms_live =
  tilingSegMap tiling "prelude" ResultPrivate perThread
  where
    perThread :: PrimExp VName -> [DimIndex SubExp] -> Builder GPU Result
    perThread in_bounds slice = do
      live_ts <- mapM lookupType prestms_live
      guarded <- protectOutOfBounds "pre" in_bounds live_ts $ do
        addPrivStms slice privstms
        addStms prestms
        pure $ varsRes prestms_live
      pure $ varsRes guarded
-- | The names bound by the patterns of @stms@ that are also free in
-- @after@; that is, the results of @stms@ that @after@ still refers to.
liveSet :: (FreeIn a) => Stms GPU -> a -> Names
liveSet stms after = bound `namesIntersection` used
  where
    bound = namesFromList $ foldMap (patNames . stmPat) stms
    used = freeIn after
-- | Recognise a statement that this pass can tile.
--
-- The statement must be a redomap 'Screma' with a single reduction, no
-- map-out results (the map lambda's return type equals the reduction
-- lambda's), at least one input array, and only primitive-typed map-lambda
-- parameters and results. It must also not carry the @unroll@ attribute.
--
-- On success, returns the width, the input arrays, and
-- @(commutativity, reduction lambda, neutral elements, map lambda)@.
tileable ::
  Stm GPU ->
  Maybe
    ( SubExp,
      [VName],
      (Commutativity, Lambda GPU, [SubExp], Lambda GPU)
    )
tileable stm = do
  Op (OtherOp (Screma w arrs form)) <- pure $ stmExp stm
  (reds, map_lam) <- isRedomapSOAC form
  let Reduce red_comm red_lam red_nes = singleReduce reds
      map_ret = lambdaReturnType map_lam
  -- No map-out arrays: every map result feeds the reduction.
  guard $ map_ret == lambdaReturnType red_lam
  guard $ not $ null arrs
  guard $ all primType map_ret
  guard $ all (primType . paramType) $ lambdaParams map_lam
  guard $ not $ "unroll" `inAttrs` stmAuxAttrs (stmAux stm)
  pure (w, arrs, (red_comm, red_lam, red_nes, map_lam))
-- | An input array of a tiled redomap, and whether it is tiled.
data InputArray
  = -- | A tiled input. The @[Int]@ is a permutation; it is named @perm@
    -- where it is consumed, and its precise use lies in the tile readers
    -- (not visible here).
    InputTile [Int] VName
  | -- | An input that is not tiled.
    InputDontTile VName
-- | The tiled inputs only, each paired with its permutation, in their
-- original order. Untiled inputs are dropped.
tiledInputs :: [InputArray] -> [(VName, [Int])]
tiledInputs inputs = [(arr, perm) | InputTile perm arr <- inputs]
-- | The per-tile counterpart of 'InputArray'.
data InputTile
  = -- | A tile array, together with the permutation of the input it came from.
    InputTiled [Int] VName
  | -- | An input that was not tiled.
    InputUntiled VName
-- | Pair the inputs with the tile arrays, in order.
--
-- Each 'InputTile' consumes the next tile name and keeps its permutation.
-- Each 'InputDontTile' is passed through and consumes no tile.
--
-- The result stops at the first 'InputTile' for which no tile name is
-- left, or when the inputs run out.
inputsToTiles :: [InputArray] -> [VName] -> [InputTile]
inputsToTiles = go
  where
    go inputs tiles = case (inputs, tiles) of
      (InputTile perm _ : rest, tile : tiles_left) ->
        InputTiled perm tile : go rest tiles_left
      (InputDontTile arr : rest, _) ->
        InputUntiled arr : go rest tiles
      _ -> []
-- | Slice the outermost dimension of @arr@ to the chunk that a given tile
-- covers.
--
-- The chunk starts at @tile_id * full_tile_size@, has @this_tile_size@
-- elements, and uses stride 1. All inner dimensions are taken in full.
-- The offset is bound as @slice_offset@ and the result as @untiled_slice@.
sliceUntiled ::
  (MonadBuilder m) =>
  VName ->
  SubExp ->
  SubExp ->
  SubExp ->
  m VName
sliceUntiled arr tile_id full_tile_size this_tile_size = do
  arr_t <- lookupType arr
  offset <-
    toExp (pe64 tile_id * pe64 full_tile_size) >>= letSubExp "slice_offset"
  let outer = DimSlice offset this_tile_size (intConst Int64 1)
  letExp "untiled_slice" . BasicOp . Index arr $ fullSlice arr_t [outer]
-- | Private per-thread statements, plus a 'ReadPrelude' action.
-- 'addPrivStms' runs the action on the thread's slice and then adds the
-- statements.
data PrivStms = PrivStms (Stms GPU) ReadPrelude
-- | Wrap plain statements as 'PrivStms' whose read prelude does nothing.
privStms :: Stms GPU -> PrivStms
privStms stms = PrivStms stms noPrelude
  where
    noPrelude _ = pure ()
-- | Emit private statements into the current builder. The read prelude is
-- run first, on the given slice; then the statements are added.
addPrivStms :: [DimIndex SubExp] -> PrivStms -> Builder GPU ()
addPrivStms slice (PrivStms body prelude) =
  prelude slice >> addStms body
-- | Combine left to right: the statements are concatenated, and the left
-- read prelude runs before the right one, both on the same slice.
instance Semigroup PrivStms where
  PrivStms lhs_stms lhs_read <> PrivStms rhs_stms rhs_read =
    PrivStms (lhs_stms <> rhs_stms) $ \slice ->
      lhs_read slice >> rhs_read slice
-- | No statements and a read prelude that does nothing.
instance Monoid PrivStms where
  mempty = PrivStms mempty (const (pure ()))
-- | A builder action that takes a slice. 'addPrivStms' runs it before
-- emitting the private statements.
type ReadPrelude = [DimIndex SubExp] -> Builder GPU ()
-- | Arguments to 'tilingProcessTile'.
data ProcessTileArgs = ProcessTileArgs
  { -- | Private statements available to the tile computation.
    processPrivStms :: PrivStms,
    -- | Commutativity of the reduction operator.
    processComm :: Commutativity,
    -- | The reduction lambda.
    processRedLam :: Lambda GPU,
    -- | The map lambda.
    processMapLam :: Lambda GPU,
    -- | The tiles (or untiled inputs) to process.
    processTiles :: [InputTile],
    -- | Accumulator variables.
    processAcc :: [VName],
    -- | Presumably the index of the tile being processed (per the name).
    processTileId :: SubExp
  }
-- | Arguments to 'tilingProcessResidualTile'.
data ResidualTileArgs = ResidualTileArgs
  { -- | Private statements available to the computation.
    residualPrivStms :: PrivStms,
    -- | Commutativity of the reduction operator.
    residualComm :: Commutativity,
    -- | The reduction lambda.
    residualRedLam :: Lambda GPU,
    -- | The map lambda.
    residualMapLam :: Lambda GPU,
    -- | The input arrays (tiled and untiled).
    residualInput :: [InputArray],
    -- | Accumulator variables.
    residualAcc :: [VName],
    -- | Presumably the total size of the input (per the name).
    residualInputSize :: SubExp,
    -- | Presumably the number of whole tiles already processed (per the name).
    residualNumWholeTiles :: SubExp
  }
-- | The operations and parameters of one tiling strategy. Generic code such
-- as 'doPrelude' and 'postludeGeneric' is written against this record.
data Tiling = Tiling
  { -- | Build a per-thread computation (a SegMap, judging by the name) with
    -- the given description and result manifest. The callback receives an
    -- in-bounds predicate and a per-thread slice, and returns this thread's
    -- results.
    tilingSegMap ::
      String ->
      ResultManifest ->
      (PrimExp VName -> [DimIndex SubExp] -> Builder GPU Result) ->
      Builder GPU [VName],
    -- | Read a tile of the given inputs. The 'SubExp' is presumably the
    -- tile index.
    tilingReadTile ::
      TileKind ->
      PrivStms ->
      SubExp ->
      [InputArray] ->
      Builder GPU [InputTile],
    -- | Process one full tile, producing new accumulators.
    tilingProcessTile ::
      ProcessTileArgs ->
      Builder GPU [VName],
    -- | Process the residual (partial) tile, producing new accumulators.
    tilingProcessResidualTile ::
      ResidualTileArgs ->
      Builder GPU [VName],
    -- | Produce the kernel result for a variable.
    tilingTileReturns :: VName -> Builder GPU KernelResult,
    -- | The segment space of the tiled kernel.
    tilingSpace :: SegSpace,
    -- | The shape of a tile.
    tilingTileShape :: Shape,
    -- | The segment level of the tiled kernel.
    tilingLevel :: SegLevel,
    -- | Compute the number of whole tiles.
    tilingNumWholeTiles :: Builder GPU SubExp
  }
-- | A tiling constructor. 'tileGeneric' applies it to the thread indices,
-- the kernel dimensions and the redomap width to obtain a 'Tiling'.
type DoTiling gtids kdims =
  gtids ->
  kdims ->
  SubExp ->
  Builder GPU Tiling
-- | Run @m@ only when @in_bounds@ holds, binding the results under @desc@.
--
-- The computation is reified as a body and placed in the true branch of an
-- @if@. The false branch yields one blank value per type in @ts@.
--
-- 'eBlank' is not used for accumulator types when a variable of exactly
-- that type is free in the body; that variable is returned instead. Other
-- types are produced with 'eBlank'. (The free-accumulator reuse is
-- presumably needed because accumulators cannot be created from nothing —
-- NOTE(review): confirm.)
protectOutOfBounds ::
  String ->
  PrimExp VName ->
  [Type] ->
  Builder GPU Result ->
  Builder GPU [VName]
protectOutOfBounds desc in_bounds ts m = do
  m_body <- insertStmsM $ mkBody mempty <$> m
  let free_vs = namesToList $ freeIn m_body
  free_ts <- mapM lookupType free_vs
  let free_accs = [(t, v) | (t, v) <- zip free_ts free_vs, isAcc t]
      blank t = case lookup t free_accs of
        Just v -> pure $ BasicOp $ SubExp $ Var v
        Nothing -> eBlank t
      else_body = eBody $ map blank ts
  if_e <- eIf (toExp in_bounds) (pure m_body) else_body
  letTupExp desc if_e
-- | Per-thread epilogue, built with 'tilingSegMap' (description
-- @"thread_res"@, 'ResultPrivate').
--
-- Each name in @pat@ is bound to this thread's element of the matching
-- tiled-loop result in @accs'@, by indexing with the thread's slice.
--
-- * If @poststms@ is empty, only the private statements are added and
--   @poststms_res@ is returned directly.
-- * Otherwise the private statements and @poststms@ run under
--   'protectOutOfBounds' (description @"postlude"@), producing values of
--   types @res_ts@.
postludeGeneric ::
  Tiling ->
  PrivStms ->
  Pat Type ->
  [VName] ->
  Stms GPU ->
  Result ->
  [Type] ->
  Builder GPU [VName]
postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts =
  tilingSegMap tiling "thread_res" ResultPrivate perThread
  where
    perThread :: PrimExp VName -> [DimIndex SubExp] -> Builder GPU Result
    perThread in_bounds slice = do
      forM_ (zip (patNames pat) accs') $ \(mine, all_threads) -> do
        all_t <- lookupType all_threads
        letBindNames [mine] . BasicOp . Index all_threads $
          fullSlice all_t slice
      if null poststms
        then addPrivStms slice privstms >> pure poststms_res
        else varsRes <$> protectOutOfBounds "postlude" in_bounds res_ts runPost
      where
        runPost = do
          addPrivStms slice privstms
          addStms poststms
          pure poststms_res
-- | A deferred builder for the tiled computation.  It is given a set of
-- names and the per-thread private statements, and emits statements
-- whose results are returned as a list of names.
type TiledBody =
  Names ->
  PrivStms ->
  Builder GPU [VName]
-- | Shared driver for tiling a redomap-like stream (@form@) inside a
-- kernel, parameterised over the tiling strategy ('DoTiling').
--
-- The setup statements produced by @doTiling@ are returned together with
-- the resulting 'Tiling' and a 'TiledBody'.  When run, the body:
--
--   1. creates per-thread accumulators holding the neutral elements
--      @red_nes@;
--   2. loops over the whole tiles, collectively reading each tile and
--      folding it into the accumulators;
--   3. processes the residual tile via 'tilingProcessResidualTile';
--   4. runs the postlude ('postludeGeneric') over @poststms@ to bind
--      @pat@.
tileGeneric ::
  DoTiling gtids kdims ->
  [Type] ->
  Pat Type ->
  gtids ->
  kdims ->
  SubExp ->
  (Commutativity, Lambda GPU, [SubExp], Lambda GPU) ->
  [InputArray] ->
  Stms GPU ->
  Result ->
  TileM (Stms GPU, Tiling, TiledBody)
tileGeneric doTiling res_ts pat gtids kdims w form inputs poststms poststms_res = do
  (tiling, tiling_stms) <- runBuilder $ doTiling gtids kdims w
  pure (tiling_stms, tiling, tiledBody tiling)
  where
    (red_comm, red_lam, red_nes, map_lam) = form

    tiledBody :: Tiling -> Names -> PrivStms -> Builder GPU [VName]
    tiledBody tiling _private privstms = do
      let tile_shape = tilingTileShape tiling

      num_whole_tiles <- tilingNumWholeTiles tiling

      -- Per-thread initial accumulator values (the neutral elements).
      mergeinits <- tilingSegMap tiling "mergeinit" ResultPrivate $ \in_bounds slice ->
        -- Neutral elements without free variables need neither the
        -- private statements nor out-of-bounds protection.
        if freeIn red_nes == mempty
          then pure $ subExpsRes red_nes
          else fmap varsRes $
            protectOutOfBounds "neutral" in_bounds (lambdaReturnType red_lam) $ do
              addPrivStms slice privstms
              pure $ subExpsRes red_nes

      -- Loop parameters: one unique accumulator array per reduction
      -- parameter, shaped by the tile shape.
      merge <- forM (zip (lambdaParams red_lam) mergeinits) $ \(p, mergeinit) ->
        (,)
          <$> newParam
            (baseString (paramName p) ++ "_merge")
            (paramType p `arrayOfShape` tile_shape `toDecl` Unique)
          <*> pure (Var mergeinit)

      tile_id <- newVName "tile_id"
      let loopform = ForLoop tile_id Int64 num_whole_tiles
      loopbody <- renameBody <=< runBodyBuilder $
        localScope (scopeOfLoopForm loopform <> scopeOfFParams (map fst merge)) $ do
          -- The loop only iterates over the whole tiles
          -- (num_whole_tiles); the leftover elements are handled by
          -- tilingProcessResidualTile after the loop.  Every read here
          -- is therefore in bounds, so use the unchecked TileFull read
          -- rather than TilePartial, which would add a bounds-check
          -- branch around every element read.
          tile <- tilingReadTile tiling TileFull privstms (Var tile_id) inputs

          let accs = map (paramName . fst) merge
              tile_args =
                ProcessTileArgs privstms red_comm red_lam map_lam tile accs (Var tile_id)
          varsRes <$> tilingProcessTile tiling tile_args

      accs <- letTupExp "accs" $ Loop merge loopform loopbody

      -- The lambdas were already used in the loop body; rename them so
      -- the residual-tile code gets fresh binding names.
      red_lam' <- renameLambda red_lam
      map_lam' <- renameLambda map_lam
      let residual_args =
            ResidualTileArgs privstms red_comm red_lam' map_lam' inputs accs w num_whole_tiles
      accs' <- tilingProcessResidualTile tiling residual_args

      postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts
-- | Build a 'ReadPrelude' that, for the given slice, binds each name in
-- @prestms_live@ to the matching array in @prestms_live_arrs@ indexed by
-- that slice (completed to a full slice of the array's type).  The two
-- lists are paired positionally; extra elements of the longer list are
-- ignored.
mkReadPreludeValues :: [VName] -> [VName] -> ReadPrelude
mkReadPreludeValues prestms_live_arrs prestms_live slice =
  zipWithM_ bindSliced prestms_live_arrs prestms_live
  where
    bindSliced arr v = do
      arr_t <- lookupType arr
      let indexing = fullSlice arr_t slice
      letBindNames [v] $ BasicOp $ Index arr indexing
-- | Produce a 'TileReturns' kernel result for @arr@.
--
-- If there are outer dimensions (@dims_on_top@) and @arr@ has array
-- dimensions, @arr@ is first reshaped with one leading unit dimension
-- per outer dimension.  Each outer dimension size is paired with 1 and
-- prepended to @dims@ to form the returned tile dimensions.
tileReturns :: [(VName, SubExp)] -> [(SubExp, SubExp)] -> VName -> Builder GPU KernelResult
tileReturns dims_on_top dims arr = do
  arr_t <- lookupType arr
  let arr_dims = arrayDims arr_t
      -- Values with no array dimensions are returned unchanged.
      must_reshape = not (null dims_on_top || null arr_dims)
  arr' <-
    if must_reshape
      then
        letExp (baseString arr) . BasicOp $
          Reshape ReshapeArbitrary (Shape (ones ++ arr_dims)) arr
      else pure arr
  pure $ TileReturns mempty (zip (map snd dims_on_top) ones ++ dims) arr'
  where
    ones = replicate (length dims_on_top) (intConst Int64 1)
-- | Classify @arr@ for 1D tiling.  The array is tiled along dimension 0
-- unless @gtid@ appears in its entry in @variance@ (a missing entry
-- counts as the empty set).
is1DTileable :: VName -> M.Map VName Names -> VName -> InputArray
is1DTileable gtid variance arr
  | gtid `nameIn` varies_with = InputDontTile arr
  | otherwise = InputTile [0] arr
  where
    varies_with = M.findWithDefault mempty arr variance
-- | Bind @gtid@ to @gid * tblock_size + ltid@, computed as an 'Int64'
-- expression.
reconstructGtids1D ::
  Count BlockSize SubExp ->
  VName ->
  VName ->
  VName ->
  Builder GPU ()
reconstructGtids1D tblock_size gtid gid ltid = do
  flat_index <- toExp $ le64 gid * pe64 (unCount tblock_size) + le64 ltid
  letBindNames [gtid] flat_index
-- | Collectively read one 1D tile of every tiled input.
--
-- A block-level 'segMap1D' of @tile_size@ threads is built.  The thread
-- with local index @local_tid@ does the following:
--
--   * computes @j = tile_id * tile_size + local_tid@;
--   * rebinds @gtid@ via 'reconstructGtids1D';
--   * emits the private statements for slice @[local_tid]@;
--   * reads element @j@ of every tiled array.
--
-- With 'TilePartial', each read is guarded by @j < w@, where @w@ is the
-- outer size of the tiled arrays.  Out-of-bounds threads produce blank
-- values instead.  'TileFull' reads unconditionally.
readTile1D ::
  SubExp ->
  VName ->
  VName ->
  KernelGrid ->
  TileKind ->
  PrivStms ->
  SubExp ->
  [InputArray] ->
  Builder GPU [InputTile]
readTile1D tile_size gid gtid (KernelGrid _num_tblocks tblock_size) kind privstms tile_id inputs = do
  tile_arrs <- segMap1D "full_tile" (SegThreadInBlock SegNoVirt) ResultNoSimplify tile_size readElems
  pure $ inputsToTiles inputs tile_arrs
  where
    tiled_arrs = map fst $ tiledInputs inputs

    readElems local_tid = do
      j <- letSubExp "j" =<< toExp (pe64 tile_id * pe64 tile_size + le64 local_tid)

      reconstructGtids1D tblock_size gtid gid local_tid
      addPrivStms [DimFix $ Var local_tid] privstms

      arr_ts <- mapM lookupType tiled_arrs
      let w = arraysSize 0 arr_ts
          -- Indexes with a single DimFix; presumably only primitive
          -- (rank-1) arrays are tiled here — TODO confirm.
          readTileElem arr = letExp "tile_elem" $ BasicOp $ Index arr $ Slice [DimFix j]
          guardedRead =
            letTupExp "pre1d"
              =<< eIf
                (toExp $ pe64 j .<. pe64 w)
                (resultBody <$> mapM (fmap Var . readTileElem) tiled_arrs)
                (eBody $ map eBlank $ map rowType arr_ts)

      varsRes <$> case kind of
        TilePartial -> guardedRead
        TileFull -> mapM readTileElem tiled_arrs
processTile1D ::
VName ->
VName ->
SubExp ->
SubExp ->
KernelGrid ->
ProcessTileArgs ->
Builder GPU [VName]
processTile1D :: VName
-> VName
-> SubExp
-> SubExp
-> KernelGrid
-> ProcessTileArgs
-> BuilderT GPU (State VNameSource) [VName]
processTile1D VName
gid VName
gtid SubExp
kdim SubExp
tile_size (KernelGrid Count NumBlocks SubExp
_num_tblocks Count BlockSize SubExp
tblock_size) ProcessTileArgs
tile_args = do
let red_comm :: Commutativity
red_comm = ProcessTileArgs -> Commutativity
processComm ProcessTileArgs
tile_args
privstms :: PrivStms
privstms = ProcessTileArgs -> PrivStms
processPrivStms ProcessTileArgs
tile_args
map_lam :: Lambda GPU
map_lam = ProcessTileArgs -> Lambda GPU
processMapLam ProcessTileArgs
tile_args
red_lam :: Lambda GPU
red_lam = ProcessTileArgs -> Lambda GPU
processRedLam ProcessTileArgs
tile_args
tiles :: [InputTile]
tiles = ProcessTileArgs -> [InputTile]
processTiles ProcessTileArgs
tile_args
tile_id :: SubExp
tile_id = ProcessTileArgs -> SubExp
processTileId ProcessTileArgs
tile_args
accs :: [VName]
accs = ProcessTileArgs -> [VName]
processAcc ProcessTileArgs
tile_args
[Char]
-> SegLevel
-> ResultManifest
-> SubExp
-> (VName -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName]
segMap1D [Char]
"acc" SegLevel
lvl ResultManifest
ResultPrivate (Count BlockSize SubExp -> SubExp
forall {k} (u :: k) e. Count u e -> e
unCount Count BlockSize SubExp
tblock_size) ((VName -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName])
-> (VName -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ \VName
ltid -> do
Count BlockSize SubExp
-> VName -> VName -> VName -> BuilderT GPU (State VNameSource) ()
reconstructGtids1D Count BlockSize SubExp
tblock_size VName
gtid VName
gid VName
ltid
[DimIndex SubExp]
-> PrivStms -> BuilderT GPU (State VNameSource) ()
addPrivStms [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid] PrivStms
privstms
[SubExp]
thread_accs <- [VName]
-> (VName -> Builder GPU SubExp)
-> BuilderT GPU (State VNameSource) [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
accs ((VName -> Builder GPU SubExp)
-> BuilderT GPU (State VNameSource) [SubExp])
-> (VName -> Builder GPU SubExp)
-> BuilderT GPU (State VNameSource) [SubExp]
forall a b. (a -> b) -> a -> b
$ \VName
acc ->
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"acc" (Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp)
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
acc (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid]
let sliceTile :: InputTile -> BuilderT GPU (State VNameSource) VName
sliceTile (InputTiled [Int]
_ VName
arr) =
VName -> BuilderT GPU (State VNameSource) VName
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
arr
sliceTile (InputUntiled VName
arr) =
VName
-> SubExp
-> SubExp
-> SubExp
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
VName -> SubExp -> SubExp -> SubExp -> m VName
sliceUntiled VName
arr SubExp
tile_id SubExp
tile_size SubExp
tile_size
[VName]
tiles' <- (InputTile -> BuilderT GPU (State VNameSource) VName)
-> [InputTile] -> BuilderT GPU (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM InputTile -> BuilderT GPU (State VNameSource) VName
sliceTile [InputTile]
tiles
let form' :: ScremaForm GPU
form' = [Reduce GPU] -> Lambda GPU -> ScremaForm GPU
forall rep. [Reduce rep] -> Lambda rep -> ScremaForm rep
redomapSOAC [Commutativity -> Lambda GPU -> [SubExp] -> Reduce GPU
forall rep. Commutativity -> Lambda rep -> [SubExp] -> Reduce rep
Reduce Commutativity
red_comm Lambda GPU
red_lam [SubExp]
thread_accs] Lambda GPU
map_lam
([VName] -> Result)
-> BuilderT GPU (State VNameSource) [VName] -> Builder GPU Result
forall a b.
(a -> b)
-> BuilderT GPU (State VNameSource) a
-> BuilderT GPU (State VNameSource) b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> Result
varsRes (BuilderT GPU (State VNameSource) [VName] -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName] -> Builder GPU Result
forall a b. (a -> b) -> a -> b
$
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) [VName]
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [VName]
letTupExp [Char]
"acc"
(Exp GPU -> BuilderT GPU (State VNameSource) [VName])
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource)))))
-> TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
kdim)
([BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource)))))
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ Op (Rep (BuilderT GPU (State VNameSource)))
-> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. Op rep -> Exp rep
Op (Op (Rep (BuilderT GPU (State VNameSource)))
-> Exp (Rep (BuilderT GPU (State VNameSource))))
-> Op (Rep (BuilderT GPU (State VNameSource)))
-> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ SOAC GPU -> HostOp SOAC GPU
forall (op :: * -> *) rep. op rep -> HostOp op rep
OtherOp (SOAC GPU -> HostOp SOAC GPU) -> SOAC GPU -> HostOp SOAC GPU
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm GPU -> SOAC GPU
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
tile_size [VName]
tiles' ScremaForm GPU
form'])
([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [SubExp]
thread_accs)
where
lvl :: SegLevel
lvl = SegVirt -> SegLevel
SegThreadInBlock SegVirt
SegNoVirt
-- | Fold the trailing, partially filled tile into the per-thread
-- accumulators.
--
-- The residual length is @w `rem` tile_size@, where @w@ is the
-- 'residualInputSize' of @args@.
--
-- * When the residual is zero, the accumulators in 'residualAcc' are
--   returned unchanged.
-- * Otherwise, tile number 'residualNumWholeTiles' is read with
--   'TilePartial', and every tiled input is cut down to its first
--   @residual_input@ elements.
-- * The result is then handed to 'processTile1D' with @residual_input@
--   in the tile-size position.
--
-- NOTE(review): untiled inputs are passed through unsliced here.
-- 'processTile1D' later slices them via 'sliceUntiled' with the tile id
-- and the tile size it receives, which is @residual_input@ in this call.
-- Confirm that the resulting offset is correct for the residual tile when
-- 'residualNumWholeTiles' is non-zero.
processResidualTile1D ::
  VName ->
  VName ->
  SubExp ->
  SubExp ->
  KernelGrid ->
  ResidualTileArgs ->
  Builder GPU [VName]
processResidualTile1D gid gtid kdim tile_size grid args = do
  -- Elements left over once every whole tile has been consumed.
  residual_input <-
    letSubExp "residual_input" . BasicOp $
      BinOp (SRem Int64 Unsafe) (residualInputSize args) tile_size

  letTupExp "acc_after_residual"
    =<< eIf
      (toExp $ pe64 residual_input .==. 0)
      (resultBodyM $ map Var accs)
      (partialTileBody residual_input)
  where
    privstms = residualPrivStms args
    accs = residualAcc args
    num_whole_tiles = residualNumWholeTiles args

    -- Keep only the first @len@ elements of a tiled input; untiled
    -- inputs are returned as they are.
    truncateTile :: SubExp -> InputTile -> Builder GPU InputTile
    truncateTile _ (InputUntiled arr) = pure $ InputUntiled arr
    truncateTile len (InputTiled perm tile) =
      InputTiled perm
        <$> letExp "partial_tile" (BasicOp $ Index tile $ Slice [prefix])
      where
        prefix = DimSlice (intConst Int64 0) len (intConst Int64 1)

    -- Branch taken when the residual is non-empty.
    partialTileBody :: SubExp -> Builder GPU (Body GPU)
    partialTileBody residual_input = runBodyBuilder $ do
      full_tiles <-
        readTile1D
          tile_size
          gid
          gtid
          grid
          TilePartial
          privstms
          num_whole_tiles
          (residualInput args)
      tiles <- mapM (truncateTile residual_input) full_tiles
      let tile_args =
            ProcessTileArgs
              privstms
              (residualComm args)
              (residualRedLam args)
              (residualMapLam args)
              tiles
              accs
              num_whole_tiles
      varsRes <$> processTile1D gid gtid kdim residual_input grid tile_args
-- | Tiling strategy for a single tiled dimension.
--
-- * Size and grid: a thread-block size is requested with 'GetSize' /
--   'SizeThreadBlock' and serves as both the block size and the tile
--   size. The grid has @kdim `divUp` tile_size@ blocks along the tiled
--   dimension, multiplied by the sizes of @dims_on_top@.
-- * Thread indices: within a block, local thread @ltid@ of block @gid@
--   gets @gtid = gid * tile_size + ltid@. The predicate @gtid < kdim@ is
--   passed to the 'tilingSegMap' callback.
-- * Whole tiles: the number of whole tiles is @w `quot` tile_size@.
tiling1d :: [(VName, SubExp)] -> DoTiling VName SubExp
tiling1d dims_on_top gtid kdim w = do
  gid <- newVName "gid"
  gid_flat <- newVName "gid_flat"

  tile_size <- do
    key <- nameFromString . prettyString <$> newVName "tile_size"
    letSubExp "tile_size" . Op . SizeOp $ GetSize key SizeThreadBlock
  let tblock_size = tile_size

  -- Blocks needed to exhaust the tiled dimension, then the total block
  -- count once every dimension on top is included.
  ldim <-
    letSubExp "ldim" . BasicOp $
      BinOp (SDivUp Int64 Unsafe) kdim tblock_size
  num_tblocks <-
    letSubExp "computed_num_tblocks"
      =<< foldBinOp (Mul Int64 OverflowUndef) ldim (map snd dims_on_top)

  let grid = KernelGrid (Count num_tblocks) (Count tblock_size)
      space = SegSpace gid_flat (dims_on_top <> [(gid, ldim)])
      tiling_lvl = SegThreadInBlock SegNoVirt

  pure
    Tiling
      { tilingSegMap = \desc manifest f ->
          segMap1D desc tiling_lvl manifest tile_size $ \ltid -> do
            letBindNames [gtid]
              =<< toExp (le64 gid * pe64 tile_size + le64 ltid)
            f (untyped $ le64 gtid .<. pe64 kdim) [DimFix (Var ltid)],
        tilingReadTile = readTile1D tile_size gid gtid grid,
        tilingProcessTile = processTile1D gid gtid kdim tile_size grid,
        tilingProcessResidualTile =
          processResidualTile1D gid gtid kdim tile_size grid,
        tilingTileReturns = tileReturns dims_on_top [(kdim, tile_size)],
        tilingTileShape = Shape [tile_size],
        tilingNumWholeTiles =
          letSubExp "num_whole_tiles" . BasicOp $
            BinOp (SQuot Int64 Unsafe) w tile_size,
        tilingLevel = SegBlock SegNoVirt (Just grid),
        tilingSpace = space
      }
-- | Decide how @arr@ should be tiled with respect to the two innermost
-- dimensions (the last two entries of @dims@).
--
-- * Too few dimensions: if @dims@ has fewer than two entries, the result
--   is 'Nothing'.
-- * Branch-variant: if either inner dimension occurs in
--   @branch_variant@, the array is not tiled.
-- * Otherwise, variance is looked up in @variance@, with a missing entry
--   treated as empty:
--   * variant only to the second-to-last dimension: @InputTile [0, 1]@;
--   * variant only to the last dimension: @InputTile [1, 0]@;
--   * any other case: 'InputDontTile'.
invariantToOneOfTwoInnerDims ::
  Names ->
  M.Map VName Names ->
  [VName] ->
  VName ->
  Maybe InputArray
invariantToOneOfTwoInnerDims branch_variant variance dims arr =
  case reverse dims of
    innermost : next : _ -> Just $ classify innermost next
    _ -> Nothing
  where
    deps = M.findWithDefault mempty arr variance

    classify innermost next
      | innermost `nameIn` branch_variant || next `nameIn` branch_variant =
          InputDontTile arr
      | next `nameIn` deps, innermost `notNameIn` deps =
          InputTile [0, 1] arr
      | innermost `nameIn` deps, next `notNameIn` deps =
          InputTile [1, 0] arr
      | otherwise =
          InputDontTile arr
-- | Bind the global thread indices of a 2D tile, x first and then y.
--
-- For each axis, @gtid = gid * tile_size + ltid@.
reconstructGtids2D ::
  SubExp ->
  (VName, VName) ->
  (VName, VName) ->
  (VName, VName) ->
  Builder GPU ()
reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y) =
  forM_ [(gtid_x, gid_x, ltid_x), (gtid_y, gid_y, ltid_y)] $ \(gtid, gid, ltid) ->
    letBindNames [gtid] =<< toExp (le64 gid * pe64 tile_size + le64 ltid)
readTile2D ::
(SubExp, SubExp) ->
(VName, VName) ->
(VName, VName) ->
SubExp ->
TileKind ->
PrivStms ->
SubExp ->
[InputArray] ->
Builder GPU [InputTile]
readTile2D :: (SubExp, SubExp)
-> (VName, VName)
-> (VName, VName)
-> SubExp
-> TileKind
-> PrivStms
-> SubExp
-> [InputArray]
-> Builder GPU [InputTile]
readTile2D (SubExp
kdim_x, SubExp
kdim_y) (VName
gtid_x, VName
gtid_y) (VName
gid_x, VName
gid_y) SubExp
tile_size TileKind
kind PrivStms
privstms SubExp
tile_id [InputArray]
inputs =
([VName] -> [InputTile])
-> BuilderT GPU (State VNameSource) [VName]
-> Builder GPU [InputTile]
forall a b.
(a -> b)
-> BuilderT GPU (State VNameSource) a
-> BuilderT GPU (State VNameSource) b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ([InputArray] -> [VName] -> [InputTile]
inputsToTiles [InputArray]
inputs)
(BuilderT GPU (State VNameSource) [VName]
-> Builder GPU [InputTile])
-> (((VName, VName) -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName])
-> ((VName, VName) -> Builder GPU Result)
-> Builder GPU [InputTile]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Char]
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp)
-> ((VName, VName) -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName]
segMap2D
[Char]
"full_tile"
(SegVirt -> Maybe KernelGrid -> SegLevel
SegThread (SegSeqDims -> SegVirt
SegNoVirtFull ([Int] -> SegSeqDims
SegSeqDims [])) Maybe KernelGrid
forall a. Maybe a
Nothing)
ResultManifest
ResultNoSimplify
(SubExp
tile_size, SubExp
tile_size)
(((VName, VName) -> Builder GPU Result) -> Builder GPU [InputTile])
-> ((VName, VName) -> Builder GPU Result)
-> Builder GPU [InputTile]
forall a b. (a -> b) -> a -> b
$ \(VName
ltid_x, VName
ltid_y) -> do
SubExp
i <-
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"i"
(Exp GPU -> Builder GPU SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU) -> Builder GPU SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_id TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_x)
SubExp
j <-
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"j"
(Exp GPU -> Builder GPU SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU) -> Builder GPU SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_id TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_y)
SubExp
-> (VName, VName)
-> (VName, VName)
-> (VName, VName)
-> BuilderT GPU (State VNameSource) ()
reconstructGtids2D SubExp
tile_size (VName
gtid_x, VName
gtid_y) (VName
gid_x, VName
gid_y) (VName
ltid_x, VName
ltid_y)
[DimIndex SubExp]
-> PrivStms -> BuilderT GPU (State VNameSource) ()
addPrivStms [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid_x, SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid_y] PrivStms
privstms
let arrs_and_perms :: [(VName, [Int])]
arrs_and_perms = [InputArray] -> [(VName, [Int])]
tiledInputs [InputArray]
inputs
readTileElem :: (VName, [Int]) -> BuilderT GPU (State VNameSource) VName
readTileElem (VName
arr, [Int]
perm) =
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp
[Char]
"tile_elem"
( BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
BasicOp -> Exp GPU
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource))))
-> (Slice SubExp -> BasicOp)
-> Slice SubExp
-> Exp (Rep (BuilderT GPU (State VNameSource)))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> Exp (Rep (BuilderT GPU (State VNameSource))))
-> Slice SubExp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$
[DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ [SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
last ([SubExp] -> SubExp) -> [SubExp] -> SubExp
forall a b. (a -> b) -> a -> b
$ [Int] -> [SubExp] -> [SubExp]
forall a. [Int] -> [a] -> [a]
rearrangeShape [Int]
perm [SubExp
i, SubExp
j]]
)
readTileElemIfInBounds :: (VName, [Int]) -> BuilderT GPU (State VNameSource) (Exp GPU)
readTileElemIfInBounds (VName
arr, [Int]
perm) = do
Type
arr_t <- VName -> BuilderT GPU (State VNameSource) Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
arr
let tile_t :: Type
tile_t = Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
arr_t
w :: SubExp
w = Int -> Type -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 Type
arr_t
idx :: SubExp
idx = [SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
last ([SubExp] -> SubExp) -> [SubExp] -> SubExp
forall a b. (a -> b) -> a -> b
$ [Int] -> [SubExp] -> [SubExp]
forall a. [Int] -> [a] -> [a]
rearrangeShape [Int]
perm [SubExp
i, SubExp
j]
othercheck :: TPrimExp Bool VName
othercheck =
[TPrimExp Bool VName] -> TPrimExp Bool VName
forall a. HasCallStack => [a] -> a
last ([TPrimExp Bool VName] -> TPrimExp Bool VName)
-> [TPrimExp Bool VName] -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$
[Int] -> [TPrimExp Bool VName] -> [TPrimExp Bool VName]
forall a. [Int] -> [a] -> [a]
rearrangeShape
[Int]
perm
[ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_y TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
kdim_y,
VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_x TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
kdim_x
]
BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource)))))
-> TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 VName
pe64 SubExp
idx TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
w TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Bool VName
othercheck)
([BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource)))))
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
idx]])
([BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [Type
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *). MonadBuilder m => Type -> m (Exp (Rep m))
eBlank Type
tile_t])
([VName] -> Result)
-> BuilderT GPU (State VNameSource) [VName] -> Builder GPU Result
forall a b.
(a -> b)
-> BuilderT GPU (State VNameSource) a
-> BuilderT GPU (State VNameSource) b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> Result
varsRes (BuilderT GPU (State VNameSource) [VName] -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName] -> Builder GPU Result
forall a b. (a -> b) -> a -> b
$
case TileKind
kind of
TileKind
TilePartial ->
((VName, [Int]) -> BuilderT GPU (State VNameSource) VName)
-> [(VName, [Int])] -> BuilderT GPU (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ([Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"pre2d" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> ((VName, [Int]) -> BuilderT GPU (State VNameSource) (Exp GPU))
-> (VName, [Int])
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< (VName, [Int]) -> BuilderT GPU (State VNameSource) (Exp GPU)
readTileElemIfInBounds) [(VName, [Int])]
arrs_and_perms
TileKind
TileFull ->
((VName, [Int]) -> BuilderT GPU (State VNameSource) VName)
-> [(VName, [Int])] -> BuilderT GPU (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (VName, [Int]) -> BuilderT GPU (State VNameSource) VName
readTileElem [(VName, [Int])]
arrs_and_perms
-- | Determine the width of a traversal over a list of input tiles.
--
-- Returns the outer dimension (dimension 0, obtained via 'lookupType')
-- of the first 'InputTiled' entry in the list.  Untiled inputs are
-- ignored.  If no input is tiled, the constant @0 :: i64@ is returned.
findTileSize :: (HasScope rep m) => [InputTile] -> m SubExp
findTileSize tiles =
  -- Entries that do not match the 'InputTiled' pattern are simply
  -- skipped by the comprehension.
  case [tile | InputTiled _ tile <- tiles] of
    [] -> pure $ intConst Int64 0
    tile : _ -> arraySize 0 <$> lookupType tile
-- | Fold one 2D tile into the per-thread accumulators.
--
-- A 'segMap2D' over a @tile_size x tile_size@ block is built.  Each
-- thread, identified by its local ids @(ltid_x, ltid_y)@, does the
-- following:
--
--   * calls 'reconstructGtids2D' and adds the private statements;
--   * reads its own accumulator values (@acc[ltid_x, ltid_y]@) and uses
--     them as the neutral elements of the reduction;
--   * if @gtid_x < kdim_x && gtid_y < kdim_y@, runs a redomap of width
--     'findTileSize' over its slices of the input tiles; otherwise it
--     returns its accumulators unchanged.
--
-- The results are kept private to each thread ('ResultPrivate').
processTile2D ::
  (VName, VName) ->
  (VName, VName) ->
  (SubExp, SubExp) ->
  SubExp ->
  ProcessTileArgs ->
  Builder GPU [VName]
processTile2D (gid_x, gid_y) (gtid_x, gtid_y) (kdim_x, kdim_y) tile_size tile_args = do
  -- Width of the traversal.  For a residual tile this is the size the
  -- tiles were cut down to, which can be smaller than tile_size.
  traversal_len <- findTileSize tiles

  segMap2D "acc" thread_lvl ResultPrivate (tile_size, tile_size) $
    \ltids@(ltid_x, ltid_y) -> do
      reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) ltids
      let ltid_slice = [DimFix (Var ltid_x), DimFix (Var ltid_y)]
      addPrivStms ltid_slice privstms

      -- Use this thread's current accumulator values as the reduction's
      -- neutral elements.
      own_accs <- forM accs $ \acc ->
        letSubExp "acc" $ BasicOp $ Index acc $ Slice ltid_slice

      let redomap = redomapSOAC [Reduce red_comm red_lam own_accs] map_lam
          threadSlice (InputUntiled arr) =
            sliceUntiled arr tile_id tile_size traversal_len
          threadSlice (InputTiled perm tile) = do
            tile_t <- lookupType tile
            -- Pick the local id matching the tile's permuted dimension.
            let ltid = head $ rearrangeShape perm [ltid_x, ltid_y]
            letExp "tile" . BasicOp . Index tile $
              sliceAt tile_t (head perm) [DimFix (Var ltid)]
      sliced <- mapM threadSlice tiles

      let in_bounds =
            le64 gtid_x .<. pe64 kdim_x .&&. le64 gtid_y .<. pe64 kdim_y
      varsRes
        <$> ( letTupExp "acc"
                =<< eIf
                  (toExp in_bounds)
                  (eBody [pure $ Op $ OtherOp $ Screma traversal_len sliced redomap])
                  (resultBodyM own_accs)
            )
  where
    privstms = processPrivStms tile_args
    red_comm = processComm tile_args
    red_lam = processRedLam tile_args
    map_lam = processMapLam tile_args
    tiles = processTiles tile_args
    accs = processAcc tile_args
    tile_id = processTileId tile_args
    thread_lvl = SegThreadInBlock (SegNoVirtFull (SegSeqDims []))
-- | Fold the residual tile into the accumulators.
--
-- The residual size is @w `rem` tile_size@, where @w@ is
-- 'residualInputSize'.  The two cases are:
--
--   * If it is zero, the accumulators from 'residualAcc' are returned
--     unchanged.
--   * Otherwise the threads read a tile with 'readTile2D' in
--     'TilePartial' mode.  Each tiled input is then sliced down to
--     @residual x residual@ and folded into the accumulators with
--     'processTile2D'.  The generated body is renamed before use.
processResidualTile2D ::
  (VName, VName) ->
  (VName, VName) ->
  (SubExp, SubExp) ->
  SubExp ->
  ResidualTileArgs ->
  Builder GPU [VName]
processResidualTile2D gids gtids kdims tile_size args = do
  -- Number of input elements left over after the whole tiles.
  leftover <-
    letSubExp "residual_input" . BasicOp $
      BinOp (SRem Int64 Unsafe) (residualInputSize args) tile_size

  let no_leftover = toExp $ pe64 leftover .==. 0
      unchanged = resultBodyM $ Var <$> residualAcc args
  letTupExp "acc_after_residual"
    =<< eIf no_leftover unchanged (partialTile leftover)
  where
    partialTile :: SubExp -> Builder GPU (Body GPU)
    partialTile leftover =
      renameBody =<< runBodyBuilder (foldPartialTile leftover)

    foldPartialTile :: SubExp -> Builder GPU Result
    foldPartialTile leftover = do
      -- With TilePartial, readTile2D guards each read with a bounds
      -- check and uses a blank value for out-of-bounds elements.
      full_tiles <-
        readTile2D
          kdims
          gtids
          gids
          tile_size
          TilePartial
          (residualPrivStms args)
          (residualNumWholeTiles args)
          (residualInput args)

      let upto_leftover =
            DimSlice (intConst Int64 0) leftover (intConst Int64 1)
          shrink (InputTiled perm tile) =
            InputTiled perm
              <$> letExp
                "partial_tile"
                (BasicOp $ Index tile $ Slice [upto_leftover, upto_leftover])
          shrink (InputUntiled arr) =
            pure $ InputUntiled arr
      partial_tiles <- mapM shrink full_tiles

      let tile_args =
            ProcessTileArgs
              { processPrivStms = residualPrivStms args,
                processComm = residualComm args,
                processRedLam = residualRedLam args,
                processMapLam = residualMapLam args,
                processTiles = partial_tiles,
                processAcc = residualAcc args,
                processTileId = residualNumWholeTiles args
              }

      -- Each thread folds the partial tile into its accumulators.
      varsRes <$> processTile2D gids gtids kdims tile_size tile_args
tiling2d :: [(VName, SubExp)] -> DoTiling (VName, VName) (SubExp, SubExp)
tiling2d :: [(VName, SubExp)] -> DoTiling (VName, VName) (SubExp, SubExp)
tiling2d [(VName, SubExp)]
dims_on_top (VName
gtid_x, VName
gtid_y) (SubExp
kdim_x, SubExp
kdim_y) SubExp
w = do
VName
gid_x <- [Char] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid_x"
VName
gid_y <- [Char] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid_y"
Name
tile_size_key <- [Char] -> Name
nameFromString ([Char] -> Name) -> (VName -> [Char]) -> VName -> Name
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString (VName -> Name)
-> BuilderT GPU (State VNameSource) VName
-> BuilderT GPU (State VNameSource) Name
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Char] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"tile_size"
SubExp
tile_size <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"tile_size" (Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp)
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp
forall a b. (a -> b) -> a -> b
$ Op (Rep (BuilderT GPU (State VNameSource)))
-> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. Op rep -> Exp rep
Op (Op (Rep (BuilderT GPU (State VNameSource)))
-> Exp (Rep (BuilderT GPU (State VNameSource))))
-> Op (Rep (BuilderT GPU (State VNameSource)))
-> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ SizeOp -> HostOp SOAC GPU
forall (op :: * -> *) rep. SizeOp -> HostOp op rep
SizeOp (SizeOp -> HostOp SOAC GPU) -> SizeOp -> HostOp SOAC GPU
forall a b. (a -> b) -> a -> b
$ Name -> SizeClass -> SizeOp
GetSize Name
tile_size_key SizeClass
SizeTile
SubExp
tblock_size <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"tblock_size" (Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp)
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) SubExp
tile_size SubExp
tile_size
SubExp
num_tblocks_x <-
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"num_tblocks_x" (Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp)
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$
BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SDivUp IntType
Int64 Safety
Unsafe) SubExp
kdim_x SubExp
tile_size
SubExp
num_tblocks_y <-
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"num_tblocks_y" (Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp)
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$
BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SDivUp IntType
Int64 Safety
Unsafe) SubExp
kdim_y SubExp
tile_size
SubExp
num_tblocks <-
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"num_tblocks_top"
(Exp GPU -> Builder GPU SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU) -> Builder GPU SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp
-> SubExp
-> [SubExp]
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> SubExp -> [SubExp] -> m (Exp (Rep m))
foldBinOp
(IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef)
SubExp
num_tblocks_x
(SubExp
num_tblocks_y SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: ((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(VName, SubExp)]
dims_on_top)
VName
gid_flat <- [Char] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid_flat"
let grid :: KernelGrid
grid = Count NumBlocks SubExp -> Count BlockSize SubExp -> KernelGrid
KernelGrid (SubExp -> Count NumBlocks SubExp
forall {k} (u :: k) e. e -> Count u e
Count SubExp
num_tblocks) (SubExp -> Count BlockSize SubExp
forall {k} (u :: k) e. e -> Count u e
Count SubExp
tblock_size)
lvl :: SegLevel
lvl = SegVirt -> Maybe KernelGrid -> SegLevel
SegBlock (SegSeqDims -> SegVirt
SegNoVirtFull ([Int] -> SegSeqDims
SegSeqDims [])) (KernelGrid -> Maybe KernelGrid
forall a. a -> Maybe a
Just KernelGrid
grid)
space :: SegSpace
space =
VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
gid_flat ([(VName, SubExp)] -> SegSpace) -> [(VName, SubExp)] -> SegSpace
forall a b. (a -> b) -> a -> b
$
[(VName, SubExp)]
dims_on_top [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
gid_x, SubExp
num_tblocks_x), (VName
gid_y, SubExp
num_tblocks_y)]
tiling_lvl :: SegLevel
tiling_lvl = SegVirt -> SegLevel
SegThreadInBlock SegVirt
SegNoVirt
Tiling -> Builder GPU Tiling
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure
Tiling
{ tilingSegMap :: [Char]
-> ResultManifest
-> (PrimExp VName -> [DimIndex SubExp] -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName]
tilingSegMap = \[Char]
desc ResultManifest
manifest PrimExp VName -> [DimIndex SubExp] -> Builder GPU Result
f ->
[Char]
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp)
-> ((VName, VName) -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName]
segMap2D [Char]
desc SegLevel
tiling_lvl ResultManifest
manifest (SubExp
tile_size, SubExp
tile_size) (((VName, VName) -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName])
-> ((VName, VName) -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
ltid_x, VName
ltid_y) -> do
SubExp
-> (VName, VName)
-> (VName, VName)
-> (VName, VName)
-> BuilderT GPU (State VNameSource) ()
reconstructGtids2D SubExp
tile_size (VName
gtid_x, VName
gtid_y) (VName
gid_x, VName
gid_y) (VName
ltid_x, VName
ltid_y)
PrimExp VName -> [DimIndex SubExp] -> Builder GPU Result
f
( TPrimExp Bool VName -> PrimExp VName
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Bool VName -> PrimExp VName)
-> TPrimExp Bool VName -> PrimExp VName
forall a b. (a -> b) -> a -> b
$
VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_x TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
kdim_x TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_y TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
kdim_y
)
[SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid_x, SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid_y],
tilingReadTile :: TileKind
-> PrivStms -> SubExp -> [InputArray] -> Builder GPU [InputTile]
tilingReadTile = (SubExp, SubExp)
-> (VName, VName)
-> (VName, VName)
-> SubExp
-> TileKind
-> PrivStms
-> SubExp
-> [InputArray]
-> Builder GPU [InputTile]
readTile2D (SubExp
kdim_x, SubExp
kdim_y) (VName
gtid_x, VName
gtid_y) (VName
gid_x, VName
gid_y) SubExp
tile_size,
tilingProcessTile :: ProcessTileArgs -> BuilderT GPU (State VNameSource) [VName]
tilingProcessTile = (VName, VName)
-> (VName, VName)
-> (SubExp, SubExp)
-> SubExp
-> ProcessTileArgs
-> BuilderT GPU (State VNameSource) [VName]
processTile2D (VName
gid_x, VName
gid_y) (VName
gtid_x, VName
gtid_y) (SubExp
kdim_x, SubExp
kdim_y) SubExp
tile_size,
tilingProcessResidualTile :: ResidualTileArgs -> BuilderT GPU (State VNameSource) [VName]
tilingProcessResidualTile = (VName, VName)
-> (VName, VName)
-> (SubExp, SubExp)
-> SubExp
-> ResidualTileArgs
-> BuilderT GPU (State VNameSource) [VName]
processResidualTile2D (VName
gid_x, VName
gid_y) (VName
gtid_x, VName
gtid_y) (SubExp
kdim_x, SubExp
kdim_y) SubExp
tile_size,
tilingTileReturns :: VName -> BuilderT GPU (State VNameSource) KernelResult
tilingTileReturns = [(VName, SubExp)]
-> [(SubExp, SubExp)]
-> VName
-> BuilderT GPU (State VNameSource) KernelResult
tileReturns [(VName, SubExp)]
dims_on_top [(SubExp
kdim_x, SubExp
tile_size), (SubExp
kdim_y, SubExp
tile_size)],
tilingTileShape :: Shape
tilingTileShape = [SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
tile_size, SubExp
tile_size],
tilingNumWholeTiles :: Builder GPU SubExp
tilingNumWholeTiles =
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"num_whole_tiles" (Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp)
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$
BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SQuot IntType
Int64 Safety
Unsafe) SubExp
w SubExp
tile_size,
tilingLevel :: SegLevel
tilingLevel = SegLevel
lvl,
tilingSpace :: SegSpace
tilingSpace = SegSpace
space
}