{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.HistAccs (histAccsGPU) where
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.Map.Strict qualified as M
import Futhark.IR.GPU
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename
import Prelude hiding (quot)
-- | Environment of known accumulators: maps each accumulator variable
-- (an accumulator parameter of an enclosing 'WithAcc' lambda) to the
-- 'WithAccInput' it was created from.
type Accs rep = M.Map VName (WithAccInput rep)
-- | The monad the optimisation runs in: read-only access to the current
-- type scope, with a fresh-name source threaded through as state.
type OptM = ReaderT (Scope GPU) (State VNameSource)
-- | Optimise the statements of a body (via 'optimiseStms') under the given
-- accumulator environment. The body's result is kept unchanged.
optimiseBody :: Accs GPU -> Body GPU -> OptM (Body GPU)
optimiseBody accs body = do
  stms' <- optimiseStms accs (bodyStms body)
  pure $ mkBody stms' (bodyResult body)
-- | Run 'optimiseBody' on every body that 'mapExpM' visits inside the
-- expression. Each body is optimised with the scope it is mapped under added
-- to the local scope. Every other component is handled by 'identityMapper'.
optimiseExp :: Accs GPU -> Exp GPU -> OptM (Exp GPU)
optimiseExp accs = mapExpM bodyMapper
  where
    bodyMapper :: Mapper GPU GPU OptM
    bodyMapper = identityMapper {mapOnBody = onBody}

    onBody :: Scope GPU -> Body GPU -> OptM (Body GPU)
    onBody scope = localScope scope . optimiseBody accs
-- | Find the statement in @stms@ that binds @v@ as the single pattern element
-- of an 'UpdateAcc'. On success, return:
--
-- * the 'WithAccInput' of the updated accumulator, looked up in @accs@;
-- * the accumulator name;
-- * the update's indices and values;
-- * all other statements, in their original order, with that one removed.
--
-- Returns 'Nothing' if no statement binds @v@ this way. It also returns
-- 'Nothing' if the statement binding @v@ updates an accumulator that is not
-- in @accs@.
extractUpdate ::
  Accs rep ->
  VName ->
  Stms rep ->
  Maybe ((WithAccInput rep, VName, [SubExp], [SubExp]), Stms rep)
extractUpdate accs v stms = stmsHead stms >>= uncurry inspect
  where
    inspect stm rest
      | Let (Pat [PatElem pe_v _]) _ (BasicOp (UpdateAcc _ acc is vs)) <- stm,
        pe_v == v =
          (\acc_input -> ((acc_input, acc, is, vs), rest))
            <$> M.lookup acc accs
      | otherwise =
          -- Not the update we want: keep this statement and search the rest.
          (\(found, rest') -> (found, oneStm stm <> rest'))
            <$> extractUpdate accs v rest
-- | Handles a kernel body whose only result is a variable produced by an
-- 'UpdateAcc' on a known accumulator (see 'extractUpdate').
--
-- For such a body, it builds a new kernel body without that update. The new
-- body returns the update's indices followed by its values, each wrapped in
-- the original 'Returns' manifest and certificates. It also yields the
-- accumulator's 'WithAccInput' and name. Any other kernel body gives
-- 'Nothing'.
mkHistBody :: Accs GPU -> KernelBody GPU -> Maybe (KernelBody GPU, WithAccInput GPU, VName)
mkHistBody accs (KernelBody () stms [Returns rm cs (Var v)]) = do
  ((acc_input, acc, is, vs), stms') <- extractUpdate accs v stms
  let results = map (Returns rm cs) (is ++ vs)
  pure (KernelBody () stms' results, acc_input, acc)
mkHistBody _ _ = Nothing
-- | Turn an accumulator operator into a histogram operator. It drops the
-- first @shapeRank shape@ parameters and renames the resulting lambda with
-- fresh names.
-- NOTE(review): the dropped parameters are presumably the index parameters
-- of the 'WithAcc' operator; confirm.
withAccLamToHistLam :: (MonadFreshNames m) => Shape -> Lambda GPU -> m (Lambda GPU)
withAccLamToHistLam shape lam = renameLambda (lam {lambdaParams = kept_params})
  where
    kept_params = drop (shapeRank shape) (lambdaParams lam)
-- | Build a 'SegMap' at level @lvl@ that adds the arrays @arrs@ into the
-- accumulator @acc@.
--
-- The segment space has a fresh flat id (@phys_tid@) and one fresh thread
-- index per dimension of @shape@. Each thread:
--
-- * reads the element at its index from every array in @arrs@;
-- * performs a 'Safe' 'UpdateAcc' of @acc@ at that index with those elements.
--
-- The kernel returns the updated accumulator, and its single result type is
-- the type of @acc@.
addArrsToAcc ::
  (MonadBuilder m, Rep m ~ GPU) =>
  SegLevel ->
  Shape ->
  [VName] ->
  VName ->
  m (Exp GPU)
addArrsToAcc lvl shape arrs acc = do
  flat <- newVName "phys_tid"
  gtids <- replicateM (shapeRank shape) (newVName "gtid")
  let space = SegSpace flat (zip gtids (shapeDims shape))
      gtid_ses = map Var gtids
      -- Index one array at this thread's position.
      readElem arr = do
        arr_t <- lookupType arr
        letSubExp (baseString arr <> "_elem") $
          BasicOp (Index arr (fullSlice arr_t (map DimFix gtid_ses)))

  (acc', stms) <- localScope (scopeOfSegSpace space) $ collectStms $ do
    elems <- mapM readElem arrs
    letExp (baseString acc <> "_upd") $
      BasicOp (UpdateAcc Safe acc gtid_ses elems)

  acc_t <- lookupType acc
  let kbody = KernelBody () stms [Returns ResultMaySimplify mempty (Var acc')]
  pure $ Op $ SegOp $ SegMap lvl space [acc_t] kbody
-- | Collapse a multi-dimensional segment space into a single dimension.
--
-- * The new dimension's size is the product of the original dimensions.
-- * The flat id of the space is kept.
-- * At the start of the new body, each original thread-index variable is
--   re-bound by unflattening the new single index. This lets the original
--   kernel statements be reused unchanged.
flatKernelBody ::
  (MonadBuilder m) =>
  SegSpace ->
  KernelBody (Rep m) ->
  m (SegSpace, KernelBody (Rep m))
flatKernelBody space kbody = do
  gtid <- newVName "gtid"
  let dims = segSpaceDims space
  dims_prod <-
    letSubExp "dims_prod"
      =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) dims

  let space' = SegSpace (segFlat space) [(gtid, dims_prod)]
      old_gtids = map fst (unSegSpace space)
      new_inds = unflattenIndex (map pe64 dims) (pe64 (Var gtid))

  kbody_stms <- localScope (scopeOfSegSpace space') . collectStms_ $ do
    ind_exps <- mapM toExp new_inds
    zipWithM_ (\old_gtid ind_exp -> letBindNames [old_gtid] ind_exp) old_gtids ind_exps
    addStms (kernelBodyStms kbody)

  pure (space', kbody {kernelBodyStms = kbody_stms})
-- | Optimise a single statement, possibly expanding it into several.
--
-- * 'WithAcc': the accumulator parameters of the lambda (those after the
--   first @length inputs@ parameters) are paired with the inputs and added
--   to the environment, and then the lambda body is optimised. The new
--   entries are on the left of '<>', so they take precedence over outer ones.
--
-- * 'SegMap': rewritten when the environment is non-empty, the body ends in
--   an 'UpdateAcc' on a known accumulator (see 'mkHistBody'), and that
--   accumulator has an operator returning only primitive values. It is then
--   replaced by:
--
--     1. replicated neutral elements, used as histogram destinations;
--     2. a 'SegHist' over the flattened space, with race factor 1 and the
--        operator from 'withAccLamToHistLam';
--     3. a 'SegMap' from 'addArrsToAcc' that adds the histograms into the
--        accumulator, bound to the original pattern.
--
-- * Any other statement only has the bodies inside its expression optimised.
optimiseStm :: Accs GPU -> Stm GPU -> OptM (Stms GPU)
optimiseStm accs (Let pat aux (WithAcc inputs lam)) =
  localScope (scopeOfLParams (lambdaParams lam)) $ do
    body' <- optimiseBody inner_accs (lambdaBody lam)
    pure $ oneStm $ Let pat aux $ WithAcc inputs (lam {lambdaBody = body'})
  where
    acc_params = drop (length inputs) (lambdaParams lam)
    inner_accs = M.fromList (zip (map paramName acc_params) inputs) <> accs
optimiseStm accs (Let pat aux (Op (SegOp (SegMap lvl space _ kbody))))
  | not (M.null accs),
    Just (kbody', (acc_shape, _, Just (acc_lam, acc_nes)), acc) <-
      mkHistBody accs kbody,
    all primType (lambdaReturnType acc_lam) =
      runBuilder_ $ do
        hist_dests <- mapM (letExp "hist_dest" . BasicOp . Replicate acc_shape) acc_nes
        hist_lam <- withAccLamToHistLam acc_shape acc_lam
        (space', kbody'') <- flatKernelBody space kbody'

        let idx_ts = replicate (shapeRank acc_shape) (Prim int64)
            histop =
              HistOp
                { histShape = acc_shape,
                  histRaceFactor = intConst Int64 1,
                  histDest = hist_dests,
                  histNeutral = acc_nes,
                  histOpShape = mempty,
                  histOp = hist_lam
                }
            hist_kernel =
              SegHist lvl space' (idx_ts ++ lambdaReturnType acc_lam) kbody'' [histop]

        hist_dest_upd <- letTupExp "hist_dest_upd" $ Op $ SegOp hist_kernel
        acc_e <- addArrsToAcc lvl acc_shape hist_dest_upd acc
        addStm (Let pat aux acc_e)
optimiseStm accs (Let pat aux e) = do
  e' <- optimiseExp accs e
  pure (oneStm (Let pat aux e'))
-- | Optimise each statement in turn with 'optimiseStm' and concatenate the
-- results. The scope of the statements themselves is added to the local
-- scope first.
optimiseStms :: Accs GPU -> Stms GPU -> OptM (Stms GPU)
optimiseStms accs stms = localScope (scopeOf stms) $ do
  optimised <- traverse (optimiseStm accs) (stmsToList stms)
  pure (mconcat optimised)
-- | The pass itself. It applies 'optimiseStms' with an empty accumulator
-- environment to the statements given to 'intraproceduralTransformation'.
-- The program's fresh-name source is threaded through the optimisation.
histAccsGPU :: Pass GPU GPU
histAccsGPU =
  Pass "hist accs" "Turn certain accumulations into histograms" $
    intraproceduralTransformation onStms
  where
    onStms :: (MonadFreshNames m) => Scope GPU -> Stms GPU -> m (Stms GPU)
    onStms scope stms =
      modifyNameSource $ runState $ runReaderT (optimiseStms mempty stms) scope