module Futhark.Optimise.ReduceDeviceSyncs (reduceDeviceSyncs) where
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State hiding (State)
import Data.Bifunctor (second)
import Data.Foldable
import Data.IntMap.Strict qualified as IM
import Data.List (transpose, zip4)
import Data.Map.Strict qualified as M
import Data.Sequence ((><), (|>))
import Data.Text qualified as T
import Futhark.Construct (fullSlice, mkBody, sliceDim)
import Futhark.Error
import Futhark.IR.GPU
import Futhark.MonadFreshNames
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable
import Futhark.Pass
import Futhark.Transform.Substitute
-- | An optimisation pass that moves host statements to the device in order to
-- reduce blocking memory operations (host/device synchronisations).
--
-- The program's constants are analysed and rewritten first.  Every function
-- definition is then analysed on top of the constants' 'MigrationTable' and
-- rewritten on its own via 'parPass'; each call of 'onFun' only receives the
-- host-only function set, the constants' table and the function itself.
reduceDeviceSyncs :: Pass GPU GPU
reduceDeviceSyncs =
  Pass
    "reduce device synchronizations"
    "Move host statements to device to reduce blocking memory operations."
    $ \prog -> do
      let hof = hostOnlyFunDefs $ progFuns prog
          consts_mt = analyseConsts hof (progFuns prog) (progConsts prog)
      consts <- onConsts consts_mt $ progConsts prog
      funs <- parPass (onFun hof consts_mt) (progFuns prog)
      pure $ prog {progConsts = consts, progFuns = funs}
  where
    -- Rewrite the top-level constant statements according to the migration
    -- table computed for them.
    onConsts :: (MonadFreshNames m) => MigrationTable -> Stms GPU -> m (Stms GPU)
    onConsts consts_mt stms =
      runReduceM consts_mt (optimizeStms stms)

    -- Rewrite a single function definition.  The table used is the
    -- constants' table combined (with '<>') with the function's own analysis;
    -- presumably so that uses of constants inside the function agree with how
    -- the constants themselves were rewritten.
    onFun ::
      (MonadFreshNames m) =>
      HostOnlyFuns ->
      MigrationTable ->
      FunDef GPU ->
      m (FunDef GPU)
    onFun hof consts_mt fd = do
      let mt = consts_mt <> analyseFunDef hof fd
      runReduceM mt (optimizeFunDef fd)
-- | Rewrite the statements of a function's body according to the migration
-- table in scope.
--
-- Only the body's statements are processed.  The body result is kept as is;
-- unlike 'optimizeBody', 'resolveResult' is not applied to it.
optimizeFunDef :: FunDef GPU -> ReduceM (FunDef GPU)
optimizeFunDef fd = do
  let body = funDefBody fd
  stms' <- optimizeStms (bodyStms body)
  pure $ fd {funDefBody = body {bodyStms = stms'}}
-- | Rewrite a body.  Its statements are optimised with 'optimizeStms' and its
-- result is passed through 'resolveResult'.
--
-- The body is rebuilt with a @()@ body decoration rather than reusing the
-- original one.
optimizeBody :: Body GPU -> ReduceM (Body GPU)
optimizeBody (Body _ stms res) = do
  stms' <- optimizeStms stms
  res' <- resolveResult res
  pure (Body () stms' res')
-- | Optimise a sequence of statements from left to right.
--
-- Each statement is passed to 'optimizeStm' together with the statements
-- emitted so far, and the returned sequence becomes the new accumulator.  A
-- single input statement may therefore be replaced by zero, one or several
-- output statements.
optimizeStms :: Stms GPU -> ReduceM (Stms GPU)
optimizeStms = foldM optimizeStm mempty
optimizeStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
optimizeStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
optimizeStm Stms GPU
out Stm GPU
stm = do
Bool
move <- (MigrationTable -> Bool) -> ReduceM Bool
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks (Stm GPU -> MigrationTable -> Bool
shouldMoveStm Stm GPU
stm)
if Bool
move
then Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
moveStm Stms GPU
out Stm GPU
stm
else case Stm GPU -> Exp GPU
forall rep. Stm rep -> Exp rep
stmExp Stm GPU
stm of
BasicOp (Update Safety
safety VName
arr Slice SubExp
slice (Var VName
v))
| Just [SubExp]
_ <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
slice -> do
Maybe VName
dev <- SubExp -> ReduceM (Maybe VName)
storedScalar (VName -> SubExp
Var VName
v)
case Maybe VName
dev of
Maybe VName
Nothing -> Stms GPU -> ReduceM (Stms GPU)
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm)
Just VName
dst -> do
let dims :: [DimIndex SubExp]
dims = Slice SubExp -> [DimIndex SubExp]
forall d. Slice d -> [DimIndex d]
unSlice Slice SubExp
slice
let ([DimIndex SubExp]
outer, [DimFix SubExp
i]) = Int -> [DimIndex SubExp] -> ([DimIndex SubExp], [DimIndex SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt ([DimIndex SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex SubExp]
dims Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1) [DimIndex SubExp]
dims
let one :: SubExp
one = IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1
let slice' :: Slice SubExp
slice' = [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp]
outer [DimIndex SubExp] -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. [a] -> [a] -> [a]
++ [SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice SubExp
i SubExp
one SubExp
one]
let e :: Exp rep
e = BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update Safety
safety VName
arr Slice SubExp
slice' (VName -> SubExp
Var VName
dst))
let stm' :: Stm GPU
stm' = Stm GPU
stm {stmExp = e}
Stms GPU -> ReduceM (Stms GPU)
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm')
BasicOp (Replicate (Shape [SubExp]
dims) (Var VName
v))
| Pat [PatElem VName
n LetDec GPU
arr_t] <- Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm -> do
VName
v' <- VName -> ReduceM VName
resolveName VName
v
let v_kept_on_device :: Bool
v_kept_on_device = VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
/= VName
v'
Bool
gpubody_ok <- (State -> Bool) -> ReduceM Bool
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets State -> Bool
stateGPUBodyOk
case Bool
v_kept_on_device of
Bool
False -> Stms GPU -> ReduceM (Stms GPU)
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm)
Bool
True
| (SubExp -> Bool) -> [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) [SubExp]
dims,
Just Type
t' <- Int -> Type -> Maybe Type
forall u.
Int
-> TypeBase (ShapeBase SubExp) u
-> Maybe (TypeBase (ShapeBase SubExp) u)
peelArray Int
1 Type
LetDec GPU
arr_t,
Bool
gpubody_ok -> do
let n' :: VName
n' = Name -> Int -> VName
VName (VName -> Name
baseName VName
n Name -> String -> Name
`withSuffix` String
"_inner") Int
0
let pat' :: Pat Type
pat' = [PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat [VName -> Type -> PatElem Type
forall dec. VName -> dec -> PatElem dec
PatElem VName
n' Type
t']
let e' :: Exp rep
e' = BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> SubExp -> BasicOp
Replicate ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> ShapeBase SubExp) -> [SubExp] -> ShapeBase SubExp
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
tail [SubExp]
dims) (VName -> SubExp
Var VName
v)
let stm' :: Stm GPU
stm' = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec GPU)
pat' (Stm GPU -> StmAux (ExpDec GPU)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
forall {rep}. Exp rep
e'
Stm GPU
gpubody <- RewriteM (Stm GPU) -> ReduceM (Stm GPU)
inGPUBody (Stm GPU -> RewriteM (Stm GPU)
rewriteStm Stm GPU
stm')
Stms GPU -> ReduceM (Stms GPU)
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
gpubody {stmPat = stmPat stm})
Bool
True
| [SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
last [SubExp]
dims SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1 ->
let e' :: Exp rep
e' = BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> SubExp -> BasicOp
Replicate ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> ShapeBase SubExp) -> [SubExp] -> ShapeBase SubExp
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
init [SubExp]
dims) (VName -> SubExp
Var VName
v')
stm' :: Stm GPU
stm' = Stm GPU
stm {stmExp = e'}
in Stms GPU -> ReduceM (Stms GPU)
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm')
Bool
True -> do
VName
n' <- VName -> ReduceM VName
forall (m :: * -> *). MonadFreshNames m => VName -> m VName
newName VName
n
let dims' :: [SubExp]
dims' = [SubExp]
dims [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ [IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1]
let arr_t' :: Type
arr_t' = PrimType -> ShapeBase SubExp -> NoUniqueness -> Type
forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
LetDec GPU
arr_t) ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp]
dims') NoUniqueness
NoUniqueness
let pat' :: Pat Type
pat' = [PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat [VName -> Type -> PatElem Type
forall dec. VName -> dec -> PatElem dec
PatElem VName
n' Type
arr_t']
let e' :: Exp rep
e' = BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> SubExp -> BasicOp
Replicate ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp]
dims) (VName -> SubExp
Var VName
v')
let repl :: Stm GPU
repl = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec GPU)
pat' (Stm GPU -> StmAux (ExpDec GPU)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
forall {rep}. Exp rep
e'
let aux :: StmAux ()
aux = Certs -> Attrs -> () -> StmAux ()
forall dec. Certs -> Attrs -> dec -> StmAux dec
StmAux Certs
forall a. Monoid a => a
mempty Attrs
forall a. Monoid a => a
mempty ()
let slice :: [DimIndex SubExp]
slice = (SubExp -> DimIndex SubExp) -> [SubExp] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
sliceDim (Type -> [SubExp]
forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims Type
LetDec GPU
arr_t)
let slice' :: [DimIndex SubExp]
slice' = [DimIndex SubExp]
slice [DimIndex SubExp] -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. [a] -> [a] -> [a]
++ [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0]
let idx :: Exp rep
idx = BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
n' ([DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
slice')
let index :: Stm GPU
index = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let (Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm) StmAux ()
StmAux (ExpDec GPU)
aux Exp GPU
forall {rep}. Exp rep
idx
Stms GPU -> ReduceM (Stms GPU)
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
repl Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
index)
BasicOp {} ->
Stms GPU -> ReduceM (Stms GPU)
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm)
Apply {} ->
Stms GPU -> ReduceM (Stms GPU)
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm)
Match [SubExp]
ses [Case (Body GPU)]
cases Body GPU
defbody (MatchDec [BranchType GPU]
btypes MatchSort
sort) -> do
[Stms GPU]
cases_stms <- (Case (Body GPU) -> ReduceM (Stms GPU))
-> [Case (Body GPU)] -> ReduceM [Stms GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Stms GPU -> ReduceM (Stms GPU)
optimizeStms (Stms GPU -> ReduceM (Stms GPU))
-> (Case (Body GPU) -> Stms GPU)
-> Case (Body GPU)
-> ReduceM (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Body GPU -> Stms GPU
forall rep. Body rep -> Stms rep
bodyStms (Body GPU -> Stms GPU)
-> (Case (Body GPU) -> Body GPU) -> Case (Body GPU) -> Stms GPU
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Case (Body GPU) -> Body GPU
forall body. Case body -> body
caseBody) [Case (Body GPU)]
cases
let cases_res :: [Result]
cases_res = (Case (Body GPU) -> Result) -> [Case (Body GPU)] -> [Result]
forall a b. (a -> b) -> [a] -> [b]
map (Body GPU -> Result
forall rep. Body rep -> Result
bodyResult (Body GPU -> Result)
-> (Case (Body GPU) -> Body GPU) -> Case (Body GPU) -> Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Case (Body GPU) -> Body GPU
forall body. Case body -> body
caseBody) [Case (Body GPU)]
cases
Stms GPU
defbody_stms <- Stms GPU -> ReduceM (Stms GPU)
optimizeStms (Stms GPU -> ReduceM (Stms GPU)) -> Stms GPU -> ReduceM (Stms GPU)
forall a b. (a -> b) -> a -> b
$ Body GPU -> Stms GPU
forall rep. Body rep -> Stms rep
bodyStms Body GPU
defbody
let defbody_res :: Result
defbody_res = Body GPU -> Result
forall rep. Body rep -> Result
bodyResult Body GPU
defbody
let bmerge :: ([(PatElem Type, Result, ExtType)], [Stms GPU])
-> (PatElem Type, Result, ExtType)
-> ReduceM ([(PatElem Type, Result, ExtType)], [Stms GPU])
bmerge ([(PatElem Type, Result, ExtType)]
acc, [Stms GPU]
all_stms) (PatElem Type
pe, Result
reses, ExtType
bt) = do
let onHost :: SubExp -> ReduceM Bool
onHost (Var VName
v) = (VName
v ==) (VName -> Bool) -> ReduceM VName -> ReduceM Bool
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ReduceM VName
resolveName VName
v
onHost SubExp
_ = Bool -> ReduceM Bool
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Bool
True
Bool
on_host <- [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
and ([Bool] -> Bool) -> ReduceM [Bool] -> ReduceM Bool
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SubExpRes -> ReduceM Bool) -> Result -> ReduceM [Bool]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (SubExp -> ReduceM Bool
onHost (SubExp -> ReduceM Bool)
-> (SubExpRes -> SubExp) -> SubExpRes -> ReduceM Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExpRes -> SubExp
resSubExp) Result
reses
if Bool
on_host
then
([(PatElem Type, Result, ExtType)], [Stms GPU])
-> ReduceM ([(PatElem Type, Result, ExtType)], [Stms GPU])
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((PatElem Type
pe, Result
reses, ExtType
bt) (PatElem Type, Result, ExtType)
-> [(PatElem Type, Result, ExtType)]
-> [(PatElem Type, Result, ExtType)]
forall a. a -> [a] -> [a]
: [(PatElem Type, Result, ExtType)]
acc, [Stms GPU]
all_stms)
else do
([Stms GPU]
all_stms', [VName]
arrs) <-
([(Stms GPU, VName)] -> ([Stms GPU], [VName]))
-> ReduceM [(Stms GPU, VName)] -> ReduceM ([Stms GPU], [VName])
forall a b. (a -> b) -> ReduceM a -> ReduceM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [(Stms GPU, VName)] -> ([Stms GPU], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip (ReduceM [(Stms GPU, VName)] -> ReduceM ([Stms GPU], [VName]))
-> ReduceM [(Stms GPU, VName)] -> ReduceM ([Stms GPU], [VName])
forall a b. (a -> b) -> a -> b
$
[(Stms GPU, SubExpRes)]
-> ((Stms GPU, SubExpRes) -> ReduceM (Stms GPU, VName))
-> ReduceM [(Stms GPU, VName)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Stms GPU] -> Result -> [(Stms GPU, SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Stms GPU]
all_stms Result
reses) (((Stms GPU, SubExpRes) -> ReduceM (Stms GPU, VName))
-> ReduceM [(Stms GPU, VName)])
-> ((Stms GPU, SubExpRes) -> ReduceM (Stms GPU, VName))
-> ReduceM [(Stms GPU, VName)]
forall a b. (a -> b) -> a -> b
$ \(Stms GPU
stms, SubExpRes
res) ->
Stms GPU -> SubExp -> Type -> ReduceM (Stms GPU, VName)
storeScalar Stms GPU
stms (SubExpRes -> SubExp
resSubExp SubExpRes
res) (PatElem Type -> Type
forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem Type
pe)
PatElem Type
pe' <- PatElem Type -> ReduceM (PatElem Type)
arrayizePatElem PatElem Type
pe
let bt' :: ExtType
bt' = Type -> ExtType
forall u.
TypeBase (ShapeBase SubExp) u -> TypeBase (ShapeBase ExtSize) u
staticShapes1 (PatElem Type -> Type
forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem Type
pe')
reses' :: Result
reses' = (Certs -> SubExp -> SubExpRes) -> [Certs] -> [SubExp] -> Result
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Certs -> SubExp -> SubExpRes
SubExpRes ((SubExpRes -> Certs) -> Result -> [Certs]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> Certs
resCerts Result
reses) ((VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
arrs)
([(PatElem Type, Result, ExtType)], [Stms GPU])
-> ReduceM ([(PatElem Type, Result, ExtType)], [Stms GPU])
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((PatElem Type
pe', Result
reses', ExtType
bt') (PatElem Type, Result, ExtType)
-> [(PatElem Type, Result, ExtType)]
-> [(PatElem Type, Result, ExtType)]
forall a. a -> [a] -> [a]
: [(PatElem Type, Result, ExtType)]
acc, [Stms GPU]
all_stms')
pes :: [PatElem Type]
pes = Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems (Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm)
([(PatElem Type, Result, ExtType)]
acc, ~(Stms GPU
defbody_stms' : [Stms GPU]
cases_stms')) <-
(([(PatElem Type, Result, ExtType)], [Stms GPU])
-> (PatElem Type, Result, ExtType)
-> ReduceM ([(PatElem Type, Result, ExtType)], [Stms GPU]))
-> ([(PatElem Type, Result, ExtType)], [Stms GPU])
-> [(PatElem Type, Result, ExtType)]
-> ReduceM ([(PatElem Type, Result, ExtType)], [Stms GPU])
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([(PatElem Type, Result, ExtType)], [Stms GPU])
-> (PatElem Type, Result, ExtType)
-> ReduceM ([(PatElem Type, Result, ExtType)], [Stms GPU])
bmerge ([], Stms GPU
defbody_stms Stms GPU -> [Stms GPU] -> [Stms GPU]
forall a. a -> [a] -> [a]
: [Stms GPU]
cases_stms) ([(PatElem Type, Result, ExtType)]
-> ReduceM ([(PatElem Type, Result, ExtType)], [Stms GPU]))
-> [(PatElem Type, Result, ExtType)]
-> ReduceM ([(PatElem Type, Result, ExtType)], [Stms GPU])
forall a b. (a -> b) -> a -> b
$
[PatElem Type]
-> [Result] -> [ExtType] -> [(PatElem Type, Result, ExtType)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem Type]
pes ([Result] -> [Result]
forall a. [[a]] -> [[a]]
transpose ([Result] -> [Result]) -> [Result] -> [Result]
forall a b. (a -> b) -> a -> b
$ Result
defbody_res Result -> [Result] -> [Result]
forall a. a -> [a] -> [a]
: [Result]
cases_res) [ExtType]
[BranchType GPU]
btypes
let ([PatElem Type]
pes', [Result]
reses, [ExtType]
btypes') = [(PatElem Type, Result, ExtType)]
-> ([PatElem Type], [Result], [ExtType])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(PatElem Type, Result, ExtType)]
-> [(PatElem Type, Result, ExtType)]
forall a. [a] -> [a]
reverse [(PatElem Type, Result, ExtType)]
acc)
let cases' :: [Case (Body GPU)]
cases' =
([Maybe PrimValue] -> Body GPU -> Case (Body GPU))
-> [[Maybe PrimValue]] -> [Body GPU] -> [Case (Body GPU)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith [Maybe PrimValue] -> Body GPU -> Case (Body GPU)
forall body. [Maybe PrimValue] -> body -> Case body
Case ((Case (Body GPU) -> [Maybe PrimValue])
-> [Case (Body GPU)] -> [[Maybe PrimValue]]
forall a b. (a -> b) -> [a] -> [b]
map Case (Body GPU) -> [Maybe PrimValue]
forall body. Case body -> [Maybe PrimValue]
casePat [Case (Body GPU)]
cases) ([Body GPU] -> [Case (Body GPU)])
-> [Body GPU] -> [Case (Body GPU)]
forall a b. (a -> b) -> a -> b
$
(Stms GPU -> Result -> Body GPU)
-> [Stms GPU] -> [Result] -> [Body GPU]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Stms GPU -> Result -> Body GPU
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody [Stms GPU]
cases_stms' ([Result] -> [Body GPU]) -> [Result] -> [Body GPU]
forall a b. (a -> b) -> a -> b
$
Int -> [Result] -> [Result]
forall a. Int -> [a] -> [a]
drop Int
1 ([Result] -> [Result]) -> [Result] -> [Result]
forall a b. (a -> b) -> a -> b
$
[Result] -> [Result]
forall a. [[a]] -> [[a]]
transpose [Result]
reses
defbody' :: Body GPU
defbody' = Stms GPU -> Result -> Body GPU
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody Stms GPU
defbody_stms' (Result -> Body GPU) -> Result -> Body GPU
forall a b. (a -> b) -> a -> b
$ (Result -> SubExpRes) -> [Result] -> Result
forall a b. (a -> b) -> [a] -> [b]
map Result -> SubExpRes
forall a. HasCallStack => [a] -> a
head [Result]
reses
e' :: Exp GPU
e' = [SubExp]
-> [Case (Body GPU)]
-> Body GPU
-> MatchDec (BranchType GPU)
-> Exp GPU
forall rep.
[SubExp]
-> [Case (Body rep)]
-> Body rep
-> MatchDec (BranchType rep)
-> Exp rep
Match [SubExp]
ses [Case (Body GPU)]
cases' Body GPU
defbody' ([ExtType] -> MatchSort -> MatchDec ExtType
forall rt. [rt] -> MatchSort -> MatchDec rt
MatchDec [ExtType]
btypes' MatchSort
sort)
stm' :: Stm GPU
stm' = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem Type]
pes') (Stm GPU -> StmAux (ExpDec GPU)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
e'
(Stms GPU -> (PatElem Type, PatElem Type) -> ReduceM (Stms GPU))
-> Stms GPU -> [(PatElem Type, PatElem Type)] -> ReduceM (Stms GPU)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM Stms GPU -> (PatElem Type, PatElem Type) -> ReduceM (Stms GPU)
forall {dec}.
Stms GPU -> (PatElem Type, PatElem dec) -> ReduceM (Stms GPU)
addRead (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm') ([PatElem Type] -> [PatElem Type] -> [(PatElem Type, PatElem Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem Type]
pes [PatElem Type]
pes')
Loop [(FParam GPU, SubExp)]
params LoopForm
lform Body GPU
body -> do
let lmerge :: ([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU)
-> (PatElem Type, (Param DeclType, SubExp), MigrationStatus)
-> ReduceM
([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU)
lmerge ([(PatElem Type, (Param DeclType, SubExp))]
res, Stms GPU
stms, Stms GPU
rebinds) (PatElem Type
pe, (Param DeclType, SubExp)
param, MigrationStatus
StayOnHost) =
([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU)
-> ReduceM
([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU)
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((PatElem Type
pe, (Param DeclType, SubExp)
param) (PatElem Type, (Param DeclType, SubExp))
-> [(PatElem Type, (Param DeclType, SubExp))]
-> [(PatElem Type, (Param DeclType, SubExp))]
forall a. a -> [a] -> [a]
: [(PatElem Type, (Param DeclType, SubExp))]
res, Stms GPU
stms, Stms GPU
rebinds)
lmerge ([(PatElem Type, (Param DeclType, SubExp))]
res, Stms GPU
stms, Stms GPU
rebinds) (PatElem Type
pe, (Param Attrs
_ VName
pn DeclType
pt, SubExp
pval), MigrationStatus
_) = do
PatElem Type
pe' <- PatElem Type -> ReduceM (PatElem Type)
arrayizePatElem PatElem Type
pe
(Stms GPU
stms', VName
arr) <- Stms GPU -> SubExp -> Type -> ReduceM (Stms GPU, VName)
storeScalar Stms GPU
stms SubExp
pval (DeclType -> Type
forall shape.
TypeBase shape Uniqueness -> TypeBase shape NoUniqueness
fromDecl DeclType
pt)
VName
pn' <- VName -> ReduceM VName
forall (m :: * -> *). MonadFreshNames m => VName -> m VName
newName VName
pn
let pt' :: DeclType
pt' = Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl (PatElem Type -> Type
forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem Type
pe') Uniqueness
Nonunique
let pval' :: SubExp
pval' = VName -> SubExp
Var VName
arr
let param' :: (Param DeclType, SubExp)
param' = (Attrs -> VName -> DeclType -> Param DeclType
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
forall a. Monoid a => a
mempty VName
pn' DeclType
pt', SubExp
pval')
Stms GPU
rebinds' <- (PatElem Type
pe {patElemName = pn}) PatElem Type -> (VName, Stms GPU) -> ReduceM (Stms GPU)
`migratedTo` (VName
pn', Stms GPU
rebinds)
([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU)
-> ReduceM
([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU)
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((PatElem Type
pe', (Param DeclType, SubExp)
param') (PatElem Type, (Param DeclType, SubExp))
-> [(PatElem Type, (Param DeclType, SubExp))]
-> [(PatElem Type, (Param DeclType, SubExp))]
forall a. a -> [a] -> [a]
: [(PatElem Type, (Param DeclType, SubExp))]
res, Stms GPU
stms', Stms GPU
rebinds')
MigrationTable
mt <- ReduceM MigrationTable
forall r (m :: * -> *). MonadReader r m => m r
ask
let pes :: [PatElem Type]
pes = Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems (Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm)
let mss :: [MigrationStatus]
mss = ((Param DeclType, SubExp) -> MigrationStatus)
-> [(Param DeclType, SubExp)] -> [MigrationStatus]
forall a b. (a -> b) -> [a] -> [b]
map (\(Param Attrs
_ VName
n DeclType
_, SubExp
_) -> VName -> MigrationTable -> MigrationStatus
statusOf VName
n MigrationTable
mt) [(Param DeclType, SubExp)]
[(FParam GPU, SubExp)]
params
([(PatElem Type, (Param DeclType, SubExp))]
zipped', Stms GPU
out', Stms GPU
rebinds) <-
(([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU)
-> (PatElem Type, (Param DeclType, SubExp), MigrationStatus)
-> ReduceM
([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU))
-> ([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU)
-> [(PatElem Type, (Param DeclType, SubExp), MigrationStatus)]
-> ReduceM
([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU)
-> (PatElem Type, (Param DeclType, SubExp), MigrationStatus)
-> ReduceM
([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU)
lmerge ([], Stms GPU
out, Stms GPU
forall a. Monoid a => a
mempty) ([PatElem Type]
-> [(Param DeclType, SubExp)]
-> [MigrationStatus]
-> [(PatElem Type, (Param DeclType, SubExp), MigrationStatus)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem Type]
pes [(Param DeclType, SubExp)]
[(FParam GPU, SubExp)]
params [MigrationStatus]
mss)
let ([PatElem Type]
pes', [(Param DeclType, SubExp)]
params') = [(PatElem Type, (Param DeclType, SubExp))]
-> ([PatElem Type], [(Param DeclType, SubExp)])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(PatElem Type, (Param DeclType, SubExp))]
-> [(PatElem Type, (Param DeclType, SubExp))]
forall a. [a] -> [a]
reverse [(PatElem Type, (Param DeclType, SubExp))]
zipped')
let body1 :: Body GPU
body1 = Body GPU
body {bodyStms = rebinds >< bodyStms body}
Body GPU
body2 <- Body GPU -> ReduceM (Body GPU)
optimizeBody Body GPU
body1
let zipped :: [(MigrationStatus, SubExpRes, SubExp, Type)]
zipped =
[MigrationStatus]
-> Result
-> [SubExp]
-> [Type]
-> [(MigrationStatus, SubExpRes, SubExp, Type)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4
[MigrationStatus]
mss
(Body GPU -> Result
forall rep. Body rep -> Result
bodyResult Body GPU
body2)
((SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp (Result -> [SubExp]) -> Result -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Body GPU -> Result
forall rep. Body rep -> Result
bodyResult Body GPU
body)
((PatElem Type -> Type) -> [PatElem Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map PatElem Type -> Type
forall dec. Typed dec => PatElem dec -> Type
patElemType [PatElem Type]
pes)
let rstore :: (Stms GPU, Result)
-> (MigrationStatus, SubExpRes, SubExp, Type)
-> ReduceM (Stms GPU, Result)
rstore (Stms GPU
bstms, Result
res) (MigrationStatus
StayOnHost, SubExpRes
r, SubExp
_, Type
_) =
(Stms GPU, Result) -> ReduceM (Stms GPU, Result)
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
bstms, SubExpRes
r SubExpRes -> Result -> Result
forall a. a -> [a] -> [a]
: Result
res)
rstore (Stms GPU
bstms, Result
res) (MigrationStatus
_, SubExpRes Certs
certs SubExp
_, SubExp
se, Type
t) = do
(Stms GPU
bstms', VName
dev) <- Stms GPU -> SubExp -> Type -> ReduceM (Stms GPU, VName)
storeScalar Stms GPU
bstms SubExp
se Type
t
(Stms GPU, Result) -> ReduceM (Stms GPU, Result)
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
bstms', Certs -> SubExp -> SubExpRes
SubExpRes Certs
certs (VName -> SubExp
Var VName
dev) SubExpRes -> Result -> Result
forall a. a -> [a] -> [a]
: Result
res)
(Stms GPU
bstms, Result
res) <- ((Stms GPU, Result)
-> (MigrationStatus, SubExpRes, SubExp, Type)
-> ReduceM (Stms GPU, Result))
-> (Stms GPU, Result)
-> [(MigrationStatus, SubExpRes, SubExp, Type)]
-> ReduceM (Stms GPU, Result)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Stms GPU, Result)
-> (MigrationStatus, SubExpRes, SubExp, Type)
-> ReduceM (Stms GPU, Result)
rstore (Body GPU -> Stms GPU
forall rep. Body rep -> Stms rep
bodyStms Body GPU
body2, []) [(MigrationStatus, SubExpRes, SubExp, Type)]
zipped
let body3 :: Body GPU
body3 = Body GPU
body2 {bodyStms = bstms, bodyResult = reverse res}
let e' :: Exp GPU
e' = [(FParam GPU, SubExp)] -> LoopForm -> Body GPU -> Exp GPU
forall rep.
[(FParam rep, SubExp)] -> LoopForm -> Body rep -> Exp rep
Loop [(Param DeclType, SubExp)]
[(FParam GPU, SubExp)]
params' LoopForm
lform Body GPU
body3
let stm' :: Stm GPU
stm' = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem Type]
pes') (Stm GPU -> StmAux (ExpDec GPU)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
e'
(Stms GPU -> (PatElem Type, PatElem Type) -> ReduceM (Stms GPU))
-> Stms GPU -> [(PatElem Type, PatElem Type)] -> ReduceM (Stms GPU)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM Stms GPU -> (PatElem Type, PatElem Type) -> ReduceM (Stms GPU)
forall {dec}.
Stms GPU -> (PatElem Type, PatElem dec) -> ReduceM (Stms GPU)
addRead (Stms GPU
out' Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm') ([PatElem Type] -> [PatElem Type] -> [(PatElem Type, PatElem Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem Type]
pes [PatElem Type]
pes')
WithAcc [WithAccInput GPU]
inputs Lambda GPU
lmd -> do
let getAcc :: TypeBase shape u -> VName
getAcc (Acc VName
a ShapeBase SubExp
_ [Type]
_ u
_) = VName
a
getAcc TypeBase shape u
_ =
String -> VName
forall a. String -> a
compilerBugS
String
"Type error: WithAcc expression did not return accumulator."
let accs :: [(VName, WithAccInput GPU)]
accs = (Type -> WithAccInput GPU -> (VName, WithAccInput GPU))
-> [Type] -> [WithAccInput GPU] -> [(VName, WithAccInput GPU)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\Type
t WithAccInput GPU
i -> (Type -> VName
forall {shape} {u}. TypeBase shape u -> VName
getAcc Type
t, WithAccInput GPU
i)) (Lambda GPU -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda GPU
lmd) [WithAccInput GPU]
inputs
[WithAccInput GPU]
inputs' <- ((VName, WithAccInput GPU) -> ReduceM (WithAccInput GPU))
-> [(VName, WithAccInput GPU)] -> ReduceM [WithAccInput GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ((VName -> WithAccInput GPU -> ReduceM (WithAccInput GPU))
-> (VName, WithAccInput GPU) -> ReduceM (WithAccInput GPU)
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry VName -> WithAccInput GPU -> ReduceM (WithAccInput GPU)
optimizeWithAccInput) [(VName, WithAccInput GPU)]
accs
let body :: Body GPU
body = Lambda GPU -> Body GPU
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPU
lmd
Stms GPU
stms' <- Stms GPU -> ReduceM (Stms GPU)
optimizeStms (Body GPU -> Stms GPU
forall rep. Body rep -> Stms rep
bodyStms Body GPU
body)
let rewrite :: (SubExpRes, Type, PatElem Type)
-> ReduceM (SubExpRes, Type, PatElem Type)
rewrite (SubExpRes Certs
certs SubExp
se, Type
t, PatElem Type
pe) =
do
SubExp
se' <- SubExp -> ReduceM SubExp
resolveSubExp SubExp
se
if SubExp
se SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
se'
then (SubExpRes, Type, PatElem Type)
-> ReduceM (SubExpRes, Type, PatElem Type)
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Certs -> SubExp -> SubExpRes
SubExpRes Certs
certs SubExp
se, Type
t, PatElem Type
pe)
else do
PatElem Type
pe' <- PatElem Type -> ReduceM (PatElem Type)
arrayizePatElem PatElem Type
pe
let t' :: Type
t' = PatElem Type -> Type
forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem Type
pe'
(SubExpRes, Type, PatElem Type)
-> ReduceM (SubExpRes, Type, PatElem Type)
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Certs -> SubExp -> SubExpRes
SubExpRes Certs
certs SubExp
se', Type
t', PatElem Type
pe')
let len :: Int
len = [WithAccInput GPU] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [WithAccInput GPU]
inputs
let (Result
res0, Result
res1) = Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitAt Int
len (Body GPU -> Result
forall rep. Body rep -> Result
bodyResult Body GPU
body)
let ([Type]
rts0, [Type]
rts1) = Int -> [Type] -> ([Type], [Type])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
len (Lambda GPU -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda GPU
lmd)
let pes :: [PatElem Type]
pes = Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems (Stm GPU -> Pat (LetDec GPU)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm GPU
stm)
let ([PatElem Type]
pes0, [PatElem Type]
pes1) = Int -> [PatElem Type] -> ([PatElem Type], [PatElem Type])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PatElem Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem Type]
pes Int -> Int -> Int
forall a. Num a => a -> a -> a
- Result -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
res1) [PatElem Type]
pes
(Result
res1', [Type]
rts1', [PatElem Type]
pes1') <- [(SubExpRes, Type, PatElem Type)]
-> (Result, [Type], [PatElem Type])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(SubExpRes, Type, PatElem Type)]
-> (Result, [Type], [PatElem Type]))
-> ReduceM [(SubExpRes, Type, PatElem Type)]
-> ReduceM (Result, [Type], [PatElem Type])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((SubExpRes, Type, PatElem Type)
-> ReduceM (SubExpRes, Type, PatElem Type))
-> [(SubExpRes, Type, PatElem Type)]
-> ReduceM [(SubExpRes, Type, PatElem Type)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (SubExpRes, Type, PatElem Type)
-> ReduceM (SubExpRes, Type, PatElem Type)
rewrite (Result
-> [Type] -> [PatElem Type] -> [(SubExpRes, Type, PatElem Type)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Result
res1 [Type]
rts1 [PatElem Type]
pes1)
let res' :: Result
res' = Result
res0 Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
res1'
let rts' :: [Type]
rts' = [Type]
rts0 [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
rts1'
let pes' :: [PatElem Type]
pes' = [PatElem Type]
pes0 [PatElem Type] -> [PatElem Type] -> [PatElem Type]
forall a. [a] -> [a] -> [a]
++ [PatElem Type]
pes1'
let body' :: Body GPU
body' = BodyDec GPU -> Stms GPU -> Result -> Body GPU
forall rep. BodyDec rep -> Stms rep -> Result -> Body rep
Body () Stms GPU
stms' Result
res'
let lmd' :: Lambda GPU
lmd' = Lambda GPU
lmd {lambdaBody = body', lambdaReturnType = rts'}
let e' :: Exp GPU
e' = [WithAccInput GPU] -> Lambda GPU -> Exp GPU
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [WithAccInput GPU]
inputs' Lambda GPU
lmd'
let stm' :: Stm GPU
stm' = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem Type]
pes') (Stm GPU -> StmAux (ExpDec GPU)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm GPU
stm) Exp GPU
e'
(Stms GPU -> (PatElem Type, PatElem Type) -> ReduceM (Stms GPU))
-> Stms GPU -> [(PatElem Type, PatElem Type)] -> ReduceM (Stms GPU)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM Stms GPU -> (PatElem Type, PatElem Type) -> ReduceM (Stms GPU)
forall {dec}.
Stms GPU -> (PatElem Type, PatElem dec) -> ReduceM (Stms GPU)
addRead (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm') ([PatElem Type] -> [PatElem Type] -> [(PatElem Type, PatElem Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem Type]
pes [PatElem Type]
pes')
Op Op GPU
op -> do
HostOp SOAC GPU
op' <- HostOp SOAC GPU -> ReduceM (HostOp SOAC GPU)
forall (op :: * -> *). HostOp op GPU -> ReduceM (HostOp op GPU)
optimizeHostOp Op GPU
HostOp SOAC GPU
op
Stms GPU -> ReduceM (Stms GPU)
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
out Stms GPU -> Stm GPU -> Stms GPU
forall a. Seq a -> a -> Seq a
|> Stm GPU
stm {stmExp = Op op'})
where
addRead :: Stms GPU -> (PatElem Type, PatElem dec) -> ReduceM (Stms GPU)
addRead Stms GPU
stms (pe :: PatElem Type
pe@(PatElem VName
n Type
_), PatElem VName
dev dec
_)
| VName
n VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
dev = Stms GPU -> ReduceM (Stms GPU)
forall a. a -> ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Stms GPU
stms
| Bool
otherwise = PatElem Type
pe PatElem Type -> (VName, Stms GPU) -> ReduceM (Stms GPU)
`migratedTo` (VName
dev, Stms GPU
stms)
-- | Optimize the combining operator (if any) of a 'WithAcc' input whose
-- accumulator is named @acc@.
--
-- * When the migration table says @acc@ should move ('shouldMove'), the
--   operator is rewritten with 'addReadsToLambda'.
-- * Otherwise the operator's statements are optimized with 'optimizeStms'
--   inside 'noGPUBody', i.e. with 'stateGPUBodyOk' temporarily set to
--   'False'. NOTE(review): presumably because such an operator stays on
--   host and must not gain 'GPUBody' kernels; confirm.
--
-- Inputs without an operator are returned unchanged.
optimizeWithAccInput :: VName -> WithAccInput GPU -> ReduceM (WithAccInput GPU)
optimizeWithAccInput acc (shape, arrs, combine) =
  case combine of
    Nothing -> pure (shape, arrs, Nothing)
    Just (op, nes) -> do
      device_only <- asks (shouldMove acc)
      op' <-
        if device_only
          then addReadsToLambda op
          else keepOnHost op
      pure (shape, arrs, Just (op', nes))
  where
    -- Optimize only the lambda body's statements, with GPUBody disabled.
    keepOnHost lam = do
      let body = lambdaBody lam
      stms' <- noGPUBody (optimizeStms (bodyStms body))
      pure lam {lambdaBody = body {bodyStms = stms'}}
-- | Rewrite the bodies and operators of a host-level operation using the
-- @addReadsTo*@ helpers (defined elsewhere in this module; presumably they
-- insert reads of migrated scalars where needed — confirm).
--
-- For segmented operations the operators are processed before the kernel
-- body; 'SizeOp's are returned unchanged; an 'OtherOp' is a compiler bug.
optimizeHostOp :: HostOp op GPU -> ReduceM (HostOp op GPU)
optimizeHostOp (SegOp segop) =
  SegOp <$> case segop of
    SegMap lvl space types kbody ->
      SegMap lvl space types <$> addReadsToKernelBody kbody
    SegRed lvl space types kbody ops -> do
      ops' <- mapM addReadsToSegBinOp ops
      kbody' <- addReadsToKernelBody kbody
      pure (SegRed lvl space types kbody' ops')
    SegScan lvl space types kbody ops -> do
      ops' <- mapM addReadsToSegBinOp ops
      kbody' <- addReadsToKernelBody kbody
      pure (SegScan lvl space types kbody' ops')
    SegHist lvl space types kbody ops -> do
      ops' <- mapM addReadsToHistOp ops
      kbody' <- addReadsToKernelBody kbody
      pure (SegHist lvl space types kbody' ops')
optimizeHostOp (SizeOp sop) =
  pure (SizeOp sop)
optimizeHostOp OtherOp {} =
  compilerBugS "optimizeHostOp: unhandled OtherOp"
optimizeHostOp (GPUBody types body) =
  GPUBody types <$> addReadsToBody body
-- | Append a textual suffix to a 'Name'; for example, @x `withSuffix` "_dev"@
-- yields @x_dev@.
withSuffix :: Name -> String -> Name
withSuffix name sfx = nameFromText (nameToText name <> T.pack sfx)
-- | The monad of this pass: a mutable 'State' (name source, migrated
-- variables, GPUBody permission) layered over a read-only 'MigrationTable'.
newtype ReduceM a = ReduceM (StateT State (Reader MigrationTable) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState State,
      MonadReader MigrationTable
    )
-- | Run a 'ReduceM' computation with the given migration table. The
-- computation starts from the caller's current name source, and the final
-- name source is handed back to the caller.
runReduceM :: (MonadFreshNames m) => MigrationTable -> ReduceM a -> m a
runReduceM mt (ReduceM m) = modifyNameSource run
  where
    run src =
      let (x, final) = runReader (runStateT m (initialState src)) mt
       in (x, stateNameSource final)
-- | Fresh names are drawn from the 'stateNameSource' field of the pass state.
instance MonadFreshNames ReduceM where
  getNameSource = gets stateNameSource
  putNameSource src = modify (\st -> st {stateNameSource = src})
-- | Mutable state of the 'ReduceM' monad.
data State = State
  { -- | Source of fresh 'VName's.
    stateNameSource :: VNameSource,
    -- | Variables that have been migrated to device, keyed by the 'baseTag'
    -- of the original variable. Each entry records the original variable's
    -- 'baseName', its type, the name of the device array that holds its
    -- value, and whether the original variable is still bound on host.
    stateMigrated :: IM.IntMap (Name, Type, VName, Bool),
    -- | Whether 'GPUBody' kernels may currently be introduced; it is cleared
    -- for the duration of 'noGPUBody'.
    stateGPUBodyOk :: Bool
  }
-- | A fresh pass state. It uses the given name source, has no migrated
-- variables, and permits 'GPUBody' kernels.
initialState :: VNameSource -> State
initialState ns = State ns mempty True
-- | Run a computation with 'stateGPUBodyOk' set to 'False', then restore
-- the flag to the value it had before.
noGPUBody :: ReduceM a -> ReduceM a
noGPUBody m = do
  saved <- gets stateGPUBodyOk
  setGPUBodyOk False
  m <* setGPUBodyOk saved
  where
    setGPUBodyOk ok = modify $ \st -> st {stateGPUBodyOk = ok}
-- | Create a pattern element for a one-element array whose row type is the
-- type of the given element. The array gets a fresh name derived from the
-- original name with a @_dev@ suffix.
arrayizePatElem :: PatElem Type -> ReduceM (PatElem Type)
arrayizePatElem (PatElem n t) = do
  dev <- newName $ VName (baseName n `withSuffix` "_dev") 0
  pure $ PatElem dev (t `arrayOfRow` intConst Int64 1)
-- | Record that a variable now lives only in the given device array, with no
-- host copy kept.
movedTo :: Ident -> VName -> ReduceM ()
movedTo ident arr = recordMigration False ident arr
-- | Record that a variable has been copied to the given device array while
-- remaining bound on host.
aliasedBy :: Ident -> VName -> ReduceM ()
aliasedBy ident arr = recordMigration True ident arr
-- | Insert (or overwrite) the 'stateMigrated' entry for variable @x@, keyed
-- by its 'baseTag'. The @host@ flag records whether @x@ is still bound on
-- host.
recordMigration :: Bool -> Ident -> VName -> ReduceM ()
recordMigration host (Ident x t) arr = modify record
  where
    entry = (baseName x, t, arr, host)
    record st = st {stateMigrated = IM.insert (baseTag x) entry (stateMigrated st)}
-- | Record that the value bound by @pe@ has been migrated to device array
-- @dev@, given the statements @stms@ emitted so far.
--
-- If the migration table reports that @pe@ is used on host ('usedOnHost'),
-- the migration is recorded with 'aliasedBy', and a statement re-binding
-- @pe@ to element 0 of @dev@ is appended. Otherwise the migration is recorded
-- with 'movedTo' and @stms@ is returned unchanged.
migratedTo :: PatElem Type -> (VName, Stms GPU) -> ReduceM (Stms GPU)
migratedTo pe (dev, stms) = do
  used <- asks (usedOnHost (patElemName pe))
  if not used
    then do
      patElemIdent pe `movedTo` dev
      pure stms
    else do
      patElemIdent pe `aliasedBy` dev
      pure (stms |> bind pe (eIndex dev))
-- | Make the scalar @n@ usable on host after the statements @stms@.
--
-- If @n@ was migrated to device without keeping a host binding, a read of
-- element 0 of its device array is appended to @stms@. The read is bound to
-- a fresh name (same base name as @n@), and that name is returned. In every
-- other case @stms@ and @n@ are returned unchanged.
useScalar :: Stms GPU -> VName -> ReduceM (Stms GPU, VName)
useScalar stms n = do
  entry <- gets (IM.lookup (baseTag n) . stateMigrated)
  case entry of
    Just (name, t, arr, False) -> do
      n' <- newName (VName name 0)
      pure (stms |> bind (PatElem n' t) (eIndex arr), n')
    _ -> pure (stms, n)
-- | An expression that indexes @arr@ at position 0 (an 'Int64' constant) of
-- its outermost dimension. The slice fixes only that one dimension, so this is
-- presumably meant for one-dimensional arrays (TODO confirm at call sites).
eIndex :: VName -> Exp GPU
eIndex arr = BasicOp (Index arr (Slice [DimFix zero]))
  where
    zero = intConst Int64 0
-- | Bind a single pattern element to an expression. The statement carries no
-- certificates and no attributes.
bind :: PatElem Type -> Exp GPU -> Stm GPU
bind pe = Let (Pat [pe]) emptyAux
  where
    emptyAux = StmAux mempty mempty ()
-- | The device array recorded for @se@ in 'stateMigrated', if there is one.
-- Constants never have such a record.
storedScalar :: SubExp -> ReduceM (Maybe VName)
storedScalar se = case se of
  Constant _ -> pure Nothing
  Var n -> gets (fmap arrayOf . IM.lookup (baseTag n) . stateMigrated)
  where
    arrayOf (_, _, arr, _) = arr
-- | @storeScalar stms se t@ returns a variable naming a device array that holds
-- the value of @se@, which is assumed to be of type @t@. Any statements needed
-- to create that array are appended to @stms@.
--
-- * If @se@ is a variable with an entry in 'stateMigrated', the recorded array
--   is returned and @stms@ is left unchanged.
-- * Otherwise, if @se@ is a variable and 'stateGPUBodyOk' holds, the value is
--   copied into a fresh variable inside a 'GPUBody' kernel (see 'inGPUBody'),
--   and the kernel's single result is returned.
-- * Otherwise, if @se@ is a variable, it is stored with a @replicate@ of shape
--   @[1]@.
-- * Constants are stored with a one-element array literal.
storeScalar :: Stms GPU -> SubExp -> Type -> ReduceM (Stms GPU, VName)
storeScalar stms se t = do
  entry <- case se of
    Var n -> gets $ IM.lookup (baseTag n) . stateMigrated
    _ -> pure Nothing
  case entry of
    Just (_, _, arr, _) -> pure (stms, arr)
    Nothing -> do
      gpubody_ok <- gets stateGPUBodyOk
      case se of
        Var n | gpubody_ok -> do
          n' <- newName n
          let stm = bind (PatElem n' t) (BasicOp $ SubExp se)
          gpubody <- inGPUBody (pure stm)
          -- inGPUBody binds one result per pattern element of the wrapped
          -- statement, and @stm@ binds exactly one. Match it totally instead
          -- of using the partial 'head'.
          case patElems (stmPat gpubody) of
            [pe] -> pure (stms |> gpubody, patElemName pe)
            _ -> compilerBugS "storeScalar: expected GPUBody to bind exactly one result."
        Var n -> do
          pe <- arrayizePatElem (PatElem n t)
          let shape = Shape [intConst Int64 1]
          let stm = bind pe (BasicOp $ Replicate shape se)
          pure (stms |> stm, patElemName pe)
        _ -> do
          -- Constants have no name of their own; "const" is only a base name
          -- handed to 'arrayizePatElem'.
          let n = VName (nameFromString "const") 0
          pe <- arrayizePatElem (PatElem n t)
          let stm = bind pe (BasicOp $ ArrayLit [se] t)
          pure (stms |> stm, patElemName pe)
-- | The name that should be used in place of @n@.
--
-- This is @n@ itself when @n@ has no entry in 'stateMigrated', or when its
-- entry's flag is 'True'. Otherwise it is the device array recorded for @n@.
resolveName :: VName -> ReduceM VName
resolveName n = gets (maybe n pick . IM.lookup (baseTag n) . stateMigrated)
  where
    pick (_, _, _, True) = n
    pick (_, _, arr, _) = arr
-- | Like 'resolveName', but for a t'SubExp'. Constants resolve to themselves.
resolveSubExp :: SubExp -> ReduceM SubExp
resolveSubExp se = case se of
  Var n -> fmap Var (resolveName n)
  _ -> pure se
-- | Like 'resolveSubExp', but for a 'SubExpRes'. The certificates are kept
-- as they are; only the result value is resolved.
resolveSubExpRes :: SubExpRes -> ReduceM SubExpRes
resolveSubExpRes (SubExpRes cs se) = do
  se' <- resolveSubExp se
  pure (SubExpRes cs se')
-- | Apply 'resolveSubExpRes' to every element of a result, left to right.
resolveResult :: Result -> ReduceM Result
resolveResult = traverse resolveSubExpRes
-- | @moveStm out stm@ migrates @stm@ to device by wrapping it in a 'GPUBody'
-- kernel (see 'inGPUBody'), which is appended to @out@.
--
-- Special case: when @stm@ binds a single variable to a one-element array
-- literal, only the literal's element is computed inside the kernel, and the
-- kernel is bound directly to the original pattern. No 'ArrayLit' remains.
--
-- Otherwise, each original pattern element is re-established from the
-- matching kernel result, chosen by the rank of that result:
--
-- * rank 0: bound by name to the kernel result;
-- * rank 1 with element type @unit@: bound with 'eIndex';
-- * rank 1 otherwise: delegated to 'migratedTo';
-- * higher ranks: the outermost dimension is indexed at 0.
moveStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
moveStm out (Let pat aux (BasicOp (ArrayLit [se] t')))
  | Pat [PatElem n _] <- pat = do
      -- The tag 0 name is only a template: rewriteStm renames every pattern
      -- element it binds through rewriteName.
      let inner = VName (baseName n `withSuffix` "_inner") 0
          innerStm :: Stm GPU
          innerStm = Let (Pat [PatElem inner t']) aux (BasicOp (SubExp se))
      kernel <- inGPUBody (rewriteStm innerStm)
      pure (out |> kernel {stmPat = pat})
moveStm out stm = do
  kernel <- inGPUBody (rewriteStm stm)
  let pairs = zip (patElems (stmPat stm)) (patElems (stmPat kernel))
  foldM addRead (out |> kernel) pairs
  where
    addRead :: Stms GPU -> (PatElem Type, PatElem Type) -> ReduceM (Stms GPU)
    addRead stms (pe@(PatElem _ t), PatElem dev dev_t)
      | rank == 0 = readWith (BasicOp (SubExp (Var dev)))
      | rank == 1, t == Prim Unit = readWith (eIndex dev)
      | rank == 1 = pe `migratedTo` (dev, stms)
      | otherwise = readWith (BasicOp dropOuter)
      where
        rank = arrayRank dev_t
        readWith e = pure (stms |> bind pe e)
        dropOuter = Index dev (fullSlice dev_t [DimFix (intConst Int64 0)])
-- | Run a rewrite that yields a single statement, starting from
-- 'initialRState', and wrap that statement in a 'GPUBody' kernel.
--
-- Inside the kernel body, the rewrite's accumulated prologue precedes the
-- statement. Each pattern element of the statement becomes a kernel result.
-- The kernel's own pattern binds the 'arrayizePatElem' form of each element.
-- The kernel statement carries no certificates and no attributes.
inGPUBody :: RewriteM (Stm GPU) -> ReduceM (Stm GPU)
inGPUBody m = do
  (stm, st) <- runStateT m initialRState
  let pes = patElems (stmPat stm)
      body :: Body GPU
      body = Body () (rewritePrologue st |> stm) (map result pes)
      kernel :: Exp GPU
      kernel = Op (GPUBody (map patElemType pes) body)
  pat <- Pat <$> traverse arrayizePatElem pes
  pure (Let pat (StmAux mempty mempty ()) kernel)
  where
    result = SubExpRes mempty . Var . patElemName
-- | The monad in which code is rewritten for migration to device: 'ReduceM'
-- extended with a rewrite-local 'RState'.
type RewriteM = StateT RState ReduceM

-- | State threaded through a 'RewriteM' computation.
data RState = RState
  { -- | Maps the base tags of original variables to the names that rewritten
    -- code should use instead.
    rewriteRenames :: IM.IntMap VName,
    -- | Statements that must be placed before the rewritten code.
    rewritePrologue :: Stms GPU
  }

-- | The empty rewrite state: no renames and no prologue statements.
initialRState :: RState
initialRState = RState mempty mempty
-- | Apply 'addReadsToLambda' to the operator lambda of a 'SegBinOp'.
addReadsToSegBinOp :: SegBinOp GPU -> ReduceM (SegBinOp GPU)
addReadsToSegBinOp op =
  (\lam -> op {segBinOpLambda = lam}) <$> addReadsToLambda (segBinOpLambda op)
-- | Apply 'addReadsToLambda' to the operator lambda of a 'HistOp'.
addReadsToHistOp :: HistOp GPU -> ReduceM (HistOp GPU)
addReadsToHistOp op =
  (\lam -> op {histOp = lam}) <$> addReadsToLambda (histOp op)
-- | Apply 'addReadsToBody' to the body of a lambda.
addReadsToLambda :: Lambda GPU -> ReduceM (Lambda GPU)
addReadsToLambda f =
  (\body -> f {lambdaBody = body}) <$> addReadsToBody (lambdaBody f)
-- | Substitute the free variables of a body via 'addReadsHelper', then prepend
-- the prologue statements it produced to the body's statements.
addReadsToBody :: Body GPU -> ReduceM (Body GPU)
addReadsToBody body = prepend <$> addReadsHelper body
  where
    prepend (body', prologue) = body' {bodyStms = prologue >< bodyStms body'}
-- | Substitute the free variables of a kernel body via 'addReadsHelper', then
-- prepend the prologue statements it produced to the kernel body's statements.
addReadsToKernelBody :: KernelBody GPU -> ReduceM (KernelBody GPU)
addReadsToKernelBody kbody = prepend <$> addReadsHelper kbody
  where
    prepend (kbody', prologue) =
      kbody' {kernelBodyStms = prologue >< kernelBodyStms kbody'}
-- | Pass every free variable of @x@ through 'rename', starting from
-- 'initialRState', and substitute the resulting names into @x@.
--
-- Also returns the prologue statements accumulated during renaming. The caller
-- must place them in scope before the rewritten construct.
addReadsHelper :: (FreeIn a, Substitute a) => a -> ReduceM (a, Stms GPU)
addReadsHelper x = do
  let free = namesToList (freeIn x)
  (renamed, st) <- runStateT (traverse rename free) initialRState
  let substitution = M.fromList (zip free renamed)
  pure (substituteNames substitution x, rewritePrologue st)
-- | Create a fresh name derived from @n@ and return it. The new name is
-- recorded in 'rewriteRenames' under @n@'s base tag, replacing any earlier
-- entry.
rewriteName :: VName -> RewriteM VName
rewriteName n = do
  fresh <- lift (newName n)
  let record st =
        st {rewriteRenames = IM.insert (baseTag n) fresh (rewriteRenames st)}
  modify record
  pure fresh
-- | Rewrite a body: first its statements (via 'rewriteStms'), then its result
-- (via 'renameResult'), in that order.
rewriteBody :: Body GPU -> RewriteM (Body GPU)
rewriteBody (Body _ stms res) =
  Body () <$> rewriteStms stms <*> renameResult res
-- | Rewrite a sequence of statements one at a time with 'rewriteStm',
-- accumulating the results in order.
--
-- A rewritten statement whose expression is a 'GPUBody' is inlined: its body
-- statements are appended, and each pattern element is then bound to the
-- matching body result, carrying that result's certificates. When the pattern
-- element's type has an outer dimension to peel, the result is wrapped in a
-- one-element 'ArrayLit' of the peeled type; otherwise it is bound directly.
--
-- NOTE(review): 'rewriteExp' maps every 'Op' to a compiler error, so it is
-- unclear whether the 'GPUBody' branch below is reachable. TODO confirm.
rewriteStms :: Stms GPU -> RewriteM (Stms GPU)
rewriteStms = foldM step mempty
  where
    step :: Stms GPU -> Stm GPU -> RewriteM (Stms GPU)
    step acc stm = flatten acc <$> rewriteStm stm

    flatten :: Stms GPU -> Stm GPU -> Stms GPU
    flatten acc stm' = case stmExp stm' of
      Op (GPUBody _ (Body _ inner res)) ->
        foldl' unwrap (acc >< inner) (zip (patElems (stmPat stm')) res)
      _ -> acc |> stm'

    unwrap :: Stms GPU -> (PatElem Type, SubExpRes) -> Stms GPU
    unwrap acc (pe, SubExpRes cs se) =
      acc |> Let (Pat [pe]) (StmAux cs mempty ()) (BasicOp rhs)
      where
        rhs = case peelArray 1 (typeOf pe) of
          Just t' -> ArrayLit [se] t'
          Nothing -> SubExp se
-- | Rewrite a single statement.
--
-- The expression is rewritten first, then the pattern (which gives every bound
-- name a fresh name via 'rewritePat'), then the certificates. Rewriting the
-- expression before the pattern means the expression is processed before this
-- statement's own bindings are registered in 'rewriteRenames'.
rewriteStm :: Stm GPU -> RewriteM (Stm GPU)
rewriteStm stm = do
  e' <- rewriteExp (stmExp stm)
  pat' <- rewritePat (stmPat stm)
  aux' <- rewriteStmAux (stmAux stm)
  pure (Let pat' aux' e')
-- | Apply 'rewritePatElem' to every element of a pattern, left to right.
rewritePat :: Pat Type -> RewriteM (Pat Type)
rewritePat (Pat pes) = Pat <$> traverse rewritePatElem pes
-- | Give a pattern element a fresh name (via 'rewriteName') and pass its type
-- through 'renameType'. The name is handled before the type.
rewritePatElem :: PatElem Type -> RewriteM (PatElem Type)
rewritePatElem (PatElem n t) = PatElem <$> rewriteName n <*> renameType t
-- | Pass a statement's certificates through 'renameCerts'. Its attributes are
-- kept unchanged.
rewriteStmAux :: StmAux () -> RewriteM (StmAux ())
rewriteStmAux (StmAux certs attrs _) =
  (\certs' -> StmAux certs' attrs ()) <$> renameCerts certs
-- | Rewrite an expression by mapping over its components:
--
-- * nested bodies go through 'rewriteBody';
-- * function and lambda parameters go through 'rewriteParam', which gives them
--   fresh names;
-- * sub-expressions, variable names and return/branch types go through the
--   renaming helpers.
--
-- Encountering an 'Op' is a compiler bug: such operations are never expected
-- to be migrated to device.
rewriteExp :: Exp GPU -> RewriteM (Exp GPU)
rewriteExp = mapExpM rewriter
  where
    rewriter :: Mapper GPU GPU RewriteM
    rewriter =
      Mapper
        { mapOnSubExp = renameSubExp,
          mapOnBody = const rewriteBody,
          mapOnVName = rename,
          mapOnRetType = renameExtType,
          mapOnBranchType = renameExtType,
          mapOnFParam = rewriteParam,
          mapOnLParam = rewriteParam,
          mapOnOp = \_ -> compilerBugS "Cannot migrate a host-only operation to device."
        }
rewriteParam :: Param (TypeBase Shape u) -> RewriteM (Param (TypeBase Shape u))
rewriteParam :: forall u.
Param (TypeBase (ShapeBase SubExp) u)
-> RewriteM (Param (TypeBase (ShapeBase SubExp) u))
rewriteParam (Param Attrs
attrs VName
n TypeBase (ShapeBase SubExp) u
t) = do
VName
n' <- VName -> StateT RState ReduceM VName
rewriteName VName
n
TypeBase (ShapeBase SubExp) u
t' <- TypeBase (ShapeBase SubExp) u
-> RewriteM (TypeBase (ShapeBase SubExp) u)
forall u.
TypeBase (ShapeBase SubExp) u
-> RewriteM (TypeBase (ShapeBase SubExp) u)
renameType TypeBase (ShapeBase SubExp) u
t
Param (TypeBase (ShapeBase SubExp) u)
-> RewriteM (Param (TypeBase (ShapeBase SubExp) u))
forall a. a -> StateT RState ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Attrs
-> VName
-> TypeBase (ShapeBase SubExp) u
-> Param (TypeBase (ShapeBase SubExp) u)
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
attrs VName
n' TypeBase (ShapeBase SubExp) u
t')
rename :: VName -> RewriteM VName
rename :: VName -> StateT RState ReduceM VName
rename VName
n = do
RState
st <- StateT RState ReduceM RState
forall s (m :: * -> *). MonadState s m => m s
get
let renames :: IntMap VName
renames = RState -> IntMap VName
rewriteRenames RState
st
let idx :: Int
idx = VName -> Int
baseTag VName
n
case Int -> IntMap VName -> Maybe VName
forall a. Int -> IntMap a -> Maybe a
IM.lookup Int
idx IntMap VName
renames of
Just VName
n' -> VName -> StateT RState ReduceM VName
forall a. a -> StateT RState ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
n'
Maybe VName
_ ->
do
let stms :: Stms GPU
stms = RState -> Stms GPU
rewritePrologue RState
st
(Stms GPU
stms', VName
n') <- ReduceM (Stms GPU, VName)
-> StateT RState ReduceM (Stms GPU, VName)
forall (m :: * -> *) a. Monad m => m a -> StateT RState m a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (ReduceM (Stms GPU, VName)
-> StateT RState ReduceM (Stms GPU, VName))
-> ReduceM (Stms GPU, VName)
-> StateT RState ReduceM (Stms GPU, VName)
forall a b. (a -> b) -> a -> b
$ Stms GPU -> VName -> ReduceM (Stms GPU, VName)
useScalar Stms GPU
stms VName
n
(RState -> RState) -> StateT RState ReduceM ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((RState -> RState) -> StateT RState ReduceM ())
-> (RState -> RState) -> StateT RState ReduceM ()
forall a b. (a -> b) -> a -> b
$ \RState
st' ->
RState
st'
{ rewriteRenames = IM.insert idx n' renames,
rewritePrologue = stms'
}
VName -> StateT RState ReduceM VName
forall a. a -> StateT RState ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
n'
renameResult :: Result -> RewriteM Result
renameResult :: Result -> RewriteM Result
renameResult = (SubExpRes -> StateT RState ReduceM SubExpRes)
-> Result -> RewriteM Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM SubExpRes -> StateT RState ReduceM SubExpRes
renameSubExpRes
renameSubExpRes :: SubExpRes -> RewriteM SubExpRes
renameSubExpRes :: SubExpRes -> StateT RState ReduceM SubExpRes
renameSubExpRes (SubExpRes Certs
certs SubExp
se) = do
Certs
certs' <- Certs -> RewriteM Certs
renameCerts Certs
certs
SubExp
se' <- SubExp -> StateT RState ReduceM SubExp
renameSubExp SubExp
se
SubExpRes -> StateT RState ReduceM SubExpRes
forall a. a -> StateT RState ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Certs -> SubExp -> SubExpRes
SubExpRes Certs
certs' SubExp
se')
renameCerts :: Certs -> RewriteM Certs
renameCerts :: Certs -> RewriteM Certs
renameCerts Certs
cs = [VName] -> Certs
Certs ([VName] -> Certs)
-> StateT RState ReduceM [VName] -> RewriteM Certs
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (VName -> StateT RState ReduceM VName)
-> [VName] -> StateT RState ReduceM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> StateT RState ReduceM VName
rename (Certs -> [VName]
unCerts Certs
cs)
renameSubExp :: SubExp -> RewriteM SubExp
renameSubExp :: SubExp -> StateT RState ReduceM SubExp
renameSubExp (Var VName
n) = VName -> SubExp
Var (VName -> SubExp)
-> StateT RState ReduceM VName -> StateT RState ReduceM SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> StateT RState ReduceM VName
rename VName
n
renameSubExp SubExp
se = SubExp -> StateT RState ReduceM SubExp
forall a. a -> StateT RState ReduceM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
se
renameType :: TypeBase Shape u -> RewriteM (TypeBase Shape u)
renameType :: forall u.
TypeBase (ShapeBase SubExp) u
-> RewriteM (TypeBase (ShapeBase SubExp) u)
renameType = (SubExp -> StateT RState ReduceM SubExp)
-> TypeBase (ShapeBase SubExp) u
-> StateT RState ReduceM (TypeBase (ShapeBase SubExp) u)
forall (m :: * -> *) u.
Monad m =>
(SubExp -> m SubExp)
-> TypeBase (ShapeBase SubExp) u
-> m (TypeBase (ShapeBase SubExp) u)
mapOnType SubExp -> StateT RState ReduceM SubExp
renameSubExp
renameExtType :: TypeBase ExtShape u -> RewriteM (TypeBase ExtShape u)
renameExtType :: forall u.
TypeBase (ShapeBase ExtSize) u
-> RewriteM (TypeBase (ShapeBase ExtSize) u)
renameExtType = (SubExp -> StateT RState ReduceM SubExp)
-> TypeBase (ShapeBase ExtSize) u
-> StateT RState ReduceM (TypeBase (ShapeBase ExtSize) u)
forall (m :: * -> *) u.
Monad m =>
(SubExp -> m SubExp)
-> TypeBase (ShapeBase ExtSize) u
-> m (TypeBase (ShapeBase ExtSize) u)
mapOnExtType SubExp -> StateT RState ReduceM SubExp
renameSubExp