{-# LANGUAGE TypeFamilies #-}

module Futhark.CodeGen.ImpGen.GPU.SegRed
  ( compileSegRed,
    compileSegRed',
    DoSegBody,
  )
where

import Control.Monad
import Data.List (genericLength, zip4)
import Data.Map qualified as M
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Transform.Rename
import Futhark.Util (chunks, mapAccumLM)
import Futhark.Util.IntegralExp (divUp, nextMul, quot, rem)
import Prelude hiding (quot, rem)
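
-- | Iterate over two lists in lockstep, applying a monadic action to each
-- pair of elements and discarding the results; that is,
-- @forM2_ xs ys f == forM_ (zip xs ys) (uncurry f)@.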
forM2_ :: (Monad m) => [a] -> [b] -> (a -> b -> m c) -> m ()
forM2_ xs ys f = forM_ (zip xs ys) (uncurry f)
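
-- | The maximum number of reduction operators supported in a single SegRed;
-- 'compileSegRed'' reports a compiler limitation if a kernel exceeds it.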
maxNumOps :: Int
maxNumOps = 20
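
-- | Code generation for the body of a SegRed, in continuation-passing style:
-- the body is given a continuation for saving its results, each represented
-- as a 'SubExp' paired with a (possibly empty) list of indices into it.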
type DoSegBody = ([(SubExp, [Imp.TExp Int64])] -> InKernelGen ()) -> InKernelGen ()
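
-- | The sets of intermediate arrays needed by a reduction kernel.  The
-- general case needs only per-block reduction arrays, while the
-- noncommutative case over primitive operands additionally uses collective
-- copy arrays and privately held chunks.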
data SegRedIntermediateArrays
  = GeneralSegRedInterms
      { blockRedArrs :: [VName]
      }
  | NoncommPrimSegRedInterms
      { collCopyArrs :: [VName],
        blockRedArrs :: [VName],
        privateChunks :: [VName]
      }
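
-- | Compile a 'SegRed' construct to host-level code with calls to the
-- appropriate GPU kernels: the map part of the kernel body is compiled first,
-- and the reduction operands are handed to 'compileSegRed''.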
compileSegRed ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegRed pat lvl space segbinops map_kbody = do
  emit $ Imp.DebugPrint "\n# SegRed" Nothing
  KernelAttrs {kAttrNumBlocks = num_tblocks, kAttrBlockSize = tblock_size} <-
    lvlKernelAttrs lvl
  let grid = KernelGrid num_tblocks tblock_size

  compileSegRed' pat grid space segbinops $ \red_cont ->
    sComment "apply map function" $
      compileStms mempty (kernelBodyStms map_kbody) $ do
        let (red_res, map_res) =
              splitAt (segBinOpResults segbinops) $ kernelBodyResult map_kbody

        let mapout_arrs = drop (segBinOpResults segbinops) $ patElems pat
        unless (null mapout_arrs) $
          sComment "write map-out result(s)" $ do
            zipWithM_ (compileThreadResult space) mapout_arrs map_res

        red_cont $ map ((,[]) . kernelResultSubExp) red_res
  emit $ Imp.DebugPrint "" Nothing
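
-- | The accumulator parameters of a 'SegBinOp' lambda: one per neutral
-- element.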
paramOf :: SegBinOp GPUMem -> [Param LParamMem]
paramOf (SegBinOp _ op ne _) = take (length ne) $ lambdaParams op
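
-- | Is this a reduction over scalar primitives, i.e. does every result of the
-- operator have primitive type and is the operator shape of rank zero?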
isPrimSegBinOp :: SegBinOp GPUMem -> Bool
isPrimSegBinOp segbinop =
  all primType (lambdaReturnType $ segBinOpLambda segbinop)
    && shapeRank (segBinOpShape segbinop) == 0
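
-- | Like 'compileSegRed', but where the body is supplied as a 'DoSegBody'
-- action.  Dispatches between the nonsegmented, small-segments and
-- large-segments code generators based on the segment space and segment size.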
compileSegRed' ::
  Pat LetDecMem ->
  KernelGrid ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  DoSegBody ->
  CallKernelGen ()
compileSegRed' pat grid space segbinops map_body_cont
  | genericLength segbinops > maxNumOps =
      compilerLimitationS $
        ( "compileSegRed': at most "
            <> show maxNumOps
            <> " reduction operators are supported,\nbut found kernel with "
            <> show (length segbinops)
            <> ".\n"
        )
          <> ("Pattern: " <> prettyString pat)
  | otherwise = do
      chunk_v <- dPrimV "chunk_size" . isInt64 =<< kernelConstToExp chunk_const
      case unSegSpace space of
        [(_, Constant (IntValue (Int64Value 1))), _] ->
          compileReduction (chunk_v, chunk_const) nonsegmentedReduction
        _ -> do
          let segment_size = pe64 $ last $ segSpaceDims space
              use_small_segments = segment_size * 2 .<. pe64 (unCount tblock_size) * tvExp chunk_v
          sIf
            use_small_segments
            (compileReduction (chunk_v, chunk_const) smallSegmentsReduction)
            (compileReduction (chunk_v, chunk_const) largeSegmentsReduction)
  where
    compileReduction chunk f =
      f pat num_tblocks tblock_size chunk space segbinops map_body_cont

    param_types = map paramType $ concatMap paramOf segbinops

    num_tblocks = gridNumBlocks grid
    tblock_size = gridBlockSize grid

    chunk_const =
      if Noncommutative `elem` map segBinOpComm segbinops
        && all isPrimSegBinOp segbinops
        then getChunkSize param_types
        else Imp.ValueExp $ IntValue $ intValue Int64 (1 :: Int64)
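
-- | Allocate the intermediate arrays for a first-stage reduction.  For
-- noncommutative operators over primitives this carves per-operator
-- collective-copy arrays, block reduction arrays and private chunks out of a
-- single shared-memory pool; otherwise it falls back to
-- 'generalSegRedInterms'.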
makeIntermArrays ::
  Imp.TExp Int64 ->
  SubExp ->
  SubExp ->
  [SegBinOp GPUMem] ->
  InKernelGen [SegRedIntermediateArrays]
makeIntermArrays tblock_id tblock_size chunk segbinops
  | Noncommutative <- mconcat (map segBinOpComm segbinops),
    all isPrimSegBinOp segbinops =
      noncommPrimSegRedInterms
  | otherwise =
      generalSegRedInterms False tblock_id tblock_size segbinops
  where
    params = map paramOf segbinops

    noncommPrimSegRedInterms = do
      block_worksize <- tvSize <$> dPrimV "block_worksize" block_worksize_E

      let sum_ x y = nextMul x y + tblock_size_E * y
          block_reds_lmem_requirement = foldl sum_ 0 $ concat elem_sizes
          collcopy_lmem_requirement = block_worksize_E * max_elem_size
          lmem_total_size =
            Imp.bytes $
              collcopy_lmem_requirement `sMax64` block_reds_lmem_requirement

      (_, offsets) <-
        forAccumLM2D 0 elem_sizes $ \byte_offs elem_size ->
          (,byte_offs `quot` elem_size)
            <$> dPrimVE "offset" (sum_ byte_offs elem_size)

      lmem <- sAlloc "local_mem" lmem_total_size (Space "shared")
      let arrInLMem ptype name len_se offset =
            sArray
              (name ++ "_" ++ prettyString ptype)
              ptype
              (Shape [len_se])
              lmem
              $ LMAD.iota offset [pe64 len_se]

      forM (zipWith zip params offsets) $ \ps_and_offsets -> do
        (coll_copy_arrs, block_red_arrs, priv_chunks) <-
          fmap unzip3 $ forM ps_and_offsets $ \(p, offset) -> do
            let ptype = elemType $ paramType p
            (,,)
              <$> arrInLMem ptype "coll_copy_arr" block_worksize 0
              <*> arrInLMem ptype "block_red_arr" tblock_size offset
              <*> sAllocArray
                ("chunk_" ++ prettyString ptype)
                ptype
                (Shape [chunk])
                (ScalarSpace [chunk] ptype)
        pure $ NoncommPrimSegRedInterms coll_copy_arrs block_red_arrs priv_chunks

    tblock_size_E = pe64 tblock_size
    block_worksize_E = tblock_size_E * pe64 chunk

    paramSize = primByteSize . elemType . paramType
    elem_sizes = map (map paramSize) params
    max_elem_size = maximum $ concat elem_sizes

    forAccumLM2D acc ls f = mapAccumLM (mapAccumLM f) acc ls
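
-- | Intermediate arrays for the general case: one reduction array with a
-- leading dimension of @tblock_size@ per operator parameter, either as a view
-- into the parameter's existing memory (for array-typed parameters) or as a
-- freshly allocated shared-memory array (for scalars).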
generalSegRedInterms ::
  Bool ->
  Imp.TExp Int64 ->
  SubExp ->
  [SegBinOp GPUMem] ->
  InKernelGen [SegRedIntermediateArrays]
generalSegRedInterms segmented tblock_id tblock_size segbinops =
  fmap (map GeneralSegRedInterms) . forM (map paramOf segbinops) . mapM $ \p ->
    case paramDec p of
      MemArray pt shape _ (ArrayIn mem ixfun) -> do
        let shape' = Shape [tblock_size] <> shape
        let shape_E = map pe64 $ shapeDims shape'
        sArray ("red_arr_" ++ prettyString pt) pt shape' mem $
          if segmented
            then ixfun
            else LMAD.iota (tblock_id * product shape_E) shape_E
      _ -> do
        let pt = elemType $ paramType p
            shape = Shape [tblock_size]
        sAllocArray ("red_arr_" ++ prettyString pt) pt shape $ Space "shared"
groupResultArrays ::
  SubExp ->
  SubExp ->
  [SegBinOp GPUMem] ->
  CallKernelGen [[VName]]
groupResultArrays num_virtblocks tblock_size segbinops =
  forM segbinops $ \(SegBinOp _ lam _ shape) ->
    forM (lambdaReturnType lam) $ \t -> do
      let pt = elemType t
          extra_dim
            | primType t, shapeRank shape == 0 = intConst Int64 1
            | otherwise = tblock_size
          full_shape = Shape [extra_dim, num_virtblocks] <> shape <> arrayShape t
          perm = [1 .. shapeRank full_shape - 1] ++ [0]
      sAllocArrayPerm "segred_tmp" pt full_shape (Space "device") perm
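
-- | The common signature of the three reduction code generators
-- ('nonsegmentedReduction', 'smallSegmentsReduction' and
-- 'largeSegmentsReduction').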
type DoCompileSegRed =
  Pat LetDecMem ->
  Count NumBlocks SubExp ->
  Count BlockSize SubExp ->
  (TV Int64, Imp.KernelConstExp) ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  DoSegBody ->
  CallKernelGen ()
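
-- | Generate code for a completely nonsegmented reduction: a single kernel
-- runs stage one over chunks of the input, after which the per-block partial
-- results are combined in stage two using the counter and sync arrays.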
nonsegmentedReduction :: DoCompileSegRed
nonsegmentedReduction (Pat segred_pes) num_tblocks tblock_size (chunk_v, chunk_const) space segbinops map_body_cont = do
  let (gtids, dims) = unzip $ unSegSpace space
      chunk = tvExp chunk_v
      num_tblocks_se = unCount num_tblocks
      tblock_size_se = unCount tblock_size
      tblock_size' = pe64 tblock_size_se
      global_tid = Imp.le64 $ segFlat space
      n = pe64 $ last dims

  counters <- genZeroes "counters" maxNumOps

  reds_block_res_arrs <- groupResultArrays num_tblocks_se tblock_size_se segbinops

  num_threads <-
    fmap tvSize $ dPrimV "num_threads" $ pe64 num_tblocks_se * tblock_size'

  let attrs =
        (defKernelAttrs num_tblocks tblock_size)
          { kAttrConstExps = M.singleton (tvVar chunk_v) chunk_const
          }

  sKernelThread "segred_nonseg" (segFlat space) attrs $ do
    constants <- kernelConstants <$> askEnv
    let ltid = kernelLocalThreadId constants
    let tblock_id = kernelBlockId constants

    interms <- makeIntermArrays (sExt64 tblock_id) tblock_size_se (tvSize chunk_v) segbinops
    sync_arr <- sAllocArray "sync_arr" Bool (Shape [intConst Int32 1]) $ Space "shared"

    forM_ gtids $ \v -> dPrimV_ v (0 :: Imp.TExp Int64)

    q <- dPrimVE "q" $ n `divUp` (sExt64 (kernelNumThreads constants) * chunk)

    slugs <-
      mapM (segBinOpSlug ltid tblock_id) $
        zip3 segbinops interms reds_block_res_arrs
    new_lambdas <-
      reductionStageOne
        gtids
        n
        global_tid
        q
        chunk
        (pe64 num_threads)
        slugs
        map_body_cont

    let segred_pess =
          chunks
            (map (length . segBinOpNeutral) segbinops)
            segred_pes

    forM_ (zip4 segred_pess slugs new_lambdas [0 ..]) $
      \(pes, slug, new_lambda, i) ->
        reductionStageTwo
          pes
          tblock_id
          [0]
          0
          (sExt64 $ kernelNumBlocks constants)
          slug
          new_lambda
          counters
          sync_arr
          (fromInteger i)
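
-- | Generate code for a segmented reduction where the segments are small
-- enough that several of them fit in a single threadblock, so each
-- (virtualised) block processes @segments_per_block@ segments at a time.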
smallSegmentsReduction :: DoCompileSegRed
smallSegmentsReduction :: Pat LParamMem
-> Count NumBlocks SubExp
-> Count BlockSize SubExp
-> (TV Int64, KernelConstExp)
-> SegSpace
-> [SegBinOp GPUMem]
-> DoSegBody
-> CallKernelGen ()
smallSegmentsReduction (Pat [PatElem LParamMem]
segred_pes) Count NumBlocks SubExp
num_tblocks Count BlockSize SubExp
tblock_size (TV Int64, KernelConstExp)
_ SegSpace
space [SegBinOp GPUMem]
segbinops DoSegBody
map_body_cont = do
let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
dims' :: [TPrimExp Int64 VName]
dims' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
segment_size :: TPrimExp Int64 VName
segment_size = [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. HasCallStack => [a] -> a
last [TPrimExp Int64 VName]
dims'
TPrimExp Int64 VName
segment_size_nonzero <-
String
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"segment_size_nonzero" (TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMax64 TPrimExp Int64 VName
1 TPrimExp Int64 VName
segment_size
let tblock_size_se :: SubExp
tblock_size_se = Count BlockSize SubExp -> SubExp
forall {k} (u :: k) e. Count u e -> e
unCount Count BlockSize SubExp
tblock_size
num_tblocks_se :: SubExp
num_tblocks_se = Count NumBlocks SubExp -> SubExp
forall {k} (u :: k) e. Count u e -> e
unCount Count NumBlocks SubExp
num_tblocks
num_tblocks' :: TPrimExp Int64 VName
num_tblocks' = SubExp -> TPrimExp Int64 VName
pe64 SubExp
num_tblocks_se
tblock_size' :: TPrimExp Int64 VName
tblock_size' = SubExp -> TPrimExp Int64 VName
pe64 SubExp
tblock_size_se
SubExp
num_threads <- (TV Int64 -> SubExp)
-> ImpM GPUMem HostEnv HostOp (TV Int64)
-> ImpM GPUMem HostEnv HostOp SubExp
forall a b.
(a -> b)
-> ImpM GPUMem HostEnv HostOp a -> ImpM GPUMem HostEnv HostOp b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap TV Int64 -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize (ImpM GPUMem HostEnv HostOp (TV Int64)
-> ImpM GPUMem HostEnv HostOp SubExp)
-> ImpM GPUMem HostEnv HostOp (TV Int64)
-> ImpM GPUMem HostEnv HostOp SubExp
forall a b. (a -> b) -> a -> b
$ String
-> TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"num_threads" (TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TPrimExp Int64 VName -> ImpM GPUMem HostEnv HostOp (TV Int64)
      $ num_tblocks' * tblock_size'
  let num_segments = product $ init dims'
      segments_per_block = tblock_size' `quot` segment_size_nonzero
      required_blocks = sExt32 $ num_segments `divUp` segments_per_block

  emit $ Imp.DebugPrint "# SegRed-small" Nothing
  emit $ Imp.DebugPrint "num_segments" $ Just $ untyped num_segments
  emit $ Imp.DebugPrint "segment_size" $ Just $ untyped segment_size
  emit $ Imp.DebugPrint "segments_per_block" $ Just $ untyped segments_per_block
  emit $ Imp.DebugPrint "required_blocks" $ Just $ untyped required_blocks
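  -- Each block processes 'segments_per_block' whole segments at a time, so
  -- 'required_blocks' is the number of segments divided (rounding up) by
  -- segments_per_block.  Illustration only: with a block size of 256 and a
  -- segment size of 32, one block covers 256 `quot` 32 = 8 segments.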
  sKernelThread "segred_small" (segFlat space) (defKernelAttrs num_tblocks tblock_size) $ do
    constants <- kernelConstants <$> askEnv
    let tblock_id = kernelBlockSize constants
        ltid = sExt64 $ kernelLocalThreadId constants

    interms <- generalSegRedInterms True tblock_id tblock_size_se segbinops
    let reds_arrs = map blockRedArrs interms

    virtualiseBlocks SegVirt required_blocks $ \virttblock_id -> do
      let segment_index =
            (ltid `quot` segment_size_nonzero)
              + (sExt64 virttblock_id * sExt64 segments_per_block)
          index_within_segment = ltid `rem` segment_size

      dIndexSpace (zip (init gtids) (init dims')) segment_index
      dPrimV_ (last gtids) index_within_segment

      let in_bounds =
            map_body_cont $ \red_res ->
              sComment "save results to be reduced" $ do
                let red_dests = map (,[ltid]) (concat reds_arrs)
                forM2_ red_dests red_res $ \(d, d_is) (res, res_is) ->
                  copyDWIMFix d d_is res res_is
          out_of_bounds =
            forM2_ segbinops reds_arrs $ \(SegBinOp _ _ nes _) red_arrs ->
              forM2_ red_arrs nes $ \arr ne ->
                copyDWIMFix arr [ltid] ne []

      sComment "apply map function if in bounds" $
        sIf
          ( segment_size
              .>. 0
              .&&. isActive (init $ zip gtids dims)
              .&&. ltid
              .<. segment_size * segments_per_block
          )
          in_bounds
          out_of_bounds

      sOp $ Imp.ErrorSync Imp.FenceLocal
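      -- Each thread has now written either its map result or the neutral
      -- element into the shared per-block arrays, with several short segments
      -- laid out side by side.  The segmented scan below reduces every
      -- segment independently; afterwards the last element of each segment
      -- holds that segment's final value.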
      let crossesSegment from to =
            (sExt64 to - sExt64 from) .>. (sExt64 to `rem` segment_size)
      sWhen (segment_size .>. 0) $
        sComment "perform segmented scan to imitate reduction" $
          forM2_ segbinops reds_arrs $ \(SegBinOp _ red_op _ _) red_arrs ->
            blockScan
              (Just crossesSegment)
              (sExt64 $ pe64 num_threads)
              (segment_size * segments_per_block)
              red_op
              red_arrs

      sOp $ Imp.Barrier Imp.FenceLocal
      sComment "save final values of segments" $
        sWhen
          ( sExt64 virttblock_id
              * segments_per_block
              + sExt64 ltid
              .<. num_segments
              .&&. ltid
              .<. segments_per_block
          )
          $ forM2_ segred_pes (concat reds_arrs)
          $ \pe arr -> do
            let flat_segment_index =
                  sExt64 virttblock_id * segments_per_block + sExt64 ltid
                gtids' =
                  unflattenIndex (init dims') flat_segment_index
            copyDWIMFix
              (patElemName pe)
              gtids'
              (Var arr)
              [(ltid + 1) * segment_size_nonzero - 1]

      sOp $ Imp.Barrier Imp.FenceLocal
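-- Reduction of segments that are too large for the small-segments kernel:
-- one or more blocks cooperate on each segment.  Stage one
-- ('reductionStageOne') leaves a partial result per block; if a segment is
-- handled by a single block, its first thread writes the result directly,
-- otherwise stage two ('reductionStageTwo') combines the per-block results,
-- coordinated through the global 'counters' array and the shared 'sync_arr'
-- flag.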
largeSegmentsReduction :: DoCompileSegRed
largeSegmentsReduction
  (Pat segred_pes)
  num_tblocks
  tblock_size
  (chunk_v, chunk_const)
  space
  segbinops
  map_body_cont = do
    let (gtids, dims) = unzip $ unSegSpace space
        dims' = map pe64 dims
        num_segments = product $ init dims'
        segment_size = last dims'
        num_tblocks' = pe64 $ unCount num_tblocks
        tblock_size_se = unCount tblock_size
        tblock_size' = pe64 tblock_size_se
        chunk = tvExp chunk_v

    blocks_per_segment <-
      dPrimVE "blocks_per_segment" $
        num_tblocks' `divUp` sMax64 1 num_segments
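    -- 'q' is the number of chunk-sized passes each thread must make so that
    -- the cooperating blocks cover a whole segment, i.e. the least q with
    -- q * blocks_per_segment * tblock_size * chunk >= segment_size.
    -- Illustration only: a segment of 100000 elements, 4 blocks of 256
    -- threads and a chunk of 4 give q = 100000 `divUp` 4096 = 25.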
    q <-
      dPrimVE "q" $
        segment_size `divUp` (tblock_size' * blocks_per_segment * chunk)

    num_virtblocks <-
      dPrimV "num_virtblocks" $
        blocks_per_segment * num_segments
    threads_per_segment <-
      dPrimVE "threads_per_segment" $
        blocks_per_segment * tblock_size'
    emit $ Imp.DebugPrint "# SegRed-large" Nothing
    emit $ Imp.DebugPrint "num_segments" $ Just $ untyped num_segments
    emit $ Imp.DebugPrint "segment_size" $ Just $ untyped segment_size
    emit $ Imp.DebugPrint "num_virtblocks" $ Just $ untyped $ tvExp num_virtblocks
    emit $ Imp.DebugPrint "num_tblocks" $ Just $ untyped num_tblocks'
    emit $ Imp.DebugPrint "tblock_size" $ Just $ untyped tblock_size'
    emit $ Imp.DebugPrint "q" $ Just $ untyped q
    emit $ Imp.DebugPrint "blocks_per_segment" $ Just $ untyped blocks_per_segment
    reds_block_res_arrs <- groupResultArrays (tvSize num_virtblocks) tblock_size_se segbinops

    let num_counters = maxNumOps * 1024
    counters <- genZeroes "counters" $ fromIntegral num_counters
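    -- A pool of zero-initialised global counters.  Each (operator, segment)
    -- pair later selects one entry (see 'counter_idx' below); the blocks
    -- cooperating on a segment use it to detect which of them finishes last
    -- and must therefore combine the per-block results in stage two.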
    let attrs =
          (defKernelAttrs num_tblocks tblock_size)
            { kAttrConstExps = M.singleton (tvVar chunk_v) chunk_const
            }

    sKernelThread "segred_large" (segFlat space) attrs $ do
      constants <- kernelConstants <$> askEnv
      let tblock_id = sExt64 $ kernelBlockId constants
          ltid = kernelLocalThreadId constants

      interms <- makeIntermArrays tblock_id tblock_size_se (tvSize chunk_v) segbinops
      sync_arr <- sAllocArray "sync_arr" Bool (Shape [intConst Int32 1]) $ Space "shared"

      virtualiseBlocks SegVirt (sExt32 (tvExp num_virtblocks)) $ \virttblock_id -> do
        let segment_gtids = init gtids

        flat_segment_id <-
          dPrimVE "flat_segment_id" $
            sExt64 virttblock_id `quot` blocks_per_segment

        global_tid <-
          dPrimVE "global_tid" $
            (sExt64 virttblock_id * sExt64 tblock_size' + sExt64 ltid)
              `rem` threads_per_segment

        let first_block_for_segment = flat_segment_id * blocks_per_segment
        dIndexSpace (zip segment_gtids (init dims')) flat_segment_id
        dPrim_ (last gtids) int64
        let n = pe64 $ last dims
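        -- 'global_tid' is this thread's index among the 'threads_per_segment'
        -- threads cooperating on the current segment, and
        -- 'first_block_for_segment' is the first virtual block assigned to
        -- that segment.  Below, each operator is bundled into a slug together
        -- with its intermediate arrays and per-block result arrays before
        -- stage one runs.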
        slugs <-
          mapM (segBinOpSlug ltid virttblock_id) $
            zip3 segbinops interms reds_block_res_arrs
        new_lambdas <-
          reductionStageOne
            gtids
            n
            global_tid
            q
            chunk
            threads_per_segment
            slugs
            map_body_cont
        let segred_pess =
              chunks
                (map (length . segBinOpNeutral) segbinops)
                segred_pes

            multiple_blocks_per_segment =
              forM_ (zip4 segred_pess slugs new_lambdas [0 ..]) $
                \(pes, slug, new_lambda, i) -> do
                  let counter_idx =
                        fromIntegral (i * num_counters)
                          + flat_segment_id
                            `rem` fromIntegral num_counters
                  reductionStageTwo
                    pes
                    virttblock_id
                    (map Imp.le64 segment_gtids)
                    first_block_for_segment
                    blocks_per_segment
                    slug
                    new_lambda
                    counters
                    sync_arr
                    counter_idx

            one_block_per_segment =
              sComment "first thread in block saves final result to memory" $
                forM2_ slugs segred_pess $ \slug pes ->
                  sWhen (ltid .==. 0) $
                    forM2_ pes (slugAccs slug) $ \v (acc, acc_is) ->
                      copyDWIMFix (patElemName v) (map Imp.le64 segment_gtids) (Var acc) acc_is

        sIf (blocks_per_segment .==. 1) one_block_per_segment multiple_blocks_per_segment
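-- | A 'SegBinOp' bundled with the auxiliary state the reduction kernels need
-- for it: its intermediate arrays, the per-thread accumulators (variable name
-- plus indexing), and the arrays holding each block's partial result.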
data SegBinOpSlug = SegBinOpSlug
  { slugOp :: SegBinOp GPUMem,
    slugInterms :: SegRedIntermediateArrays,
    slugAccs :: [(VName, [Imp.TExp Int64])],
    blockResArrs :: [VName]
  }
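-- For a scalar operator with no vectorised (array) part, the accumulator is a
-- freshly declared private variable; otherwise the accumulator lives directly
-- in the block result array, indexed by local thread id and block id.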
segBinOpSlug ::
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  (SegBinOp GPUMem, SegRedIntermediateArrays, [VName]) ->
  InKernelGen SegBinOpSlug
segBinOpSlug ltid tblock_id (op, interms, block_res_arrs) = do
  accs <- zipWithM mkAcc (lambdaParams (segBinOpLambda op)) block_res_arrs
  pure $ SegBinOpSlug op interms accs block_res_arrs
  where
    mkAcc p block_res_arr
      | Prim t <- paramType p,
        shapeRank (segBinOpShape op) == 0 = do
          block_res_acc <- dPrimS (baseString (paramName p) <> "_block_res_acc") t
          pure (block_res_acc, [])
      | otherwise =
          pure (block_res_arr, [sExt64 ltid, sExt64 tblock_id])
slugLambda :: SegBinOpSlug -> Lambda GPUMem
slugLambda = segBinOpLambda . slugOp

slugBody :: SegBinOpSlug -> Body GPUMem
slugBody = lambdaBody . slugLambda

slugParams :: SegBinOpSlug -> [LParam GPUMem]
slugParams = lambdaParams . slugLambda

slugNeutral :: SegBinOpSlug -> [SubExp]
slugNeutral = segBinOpNeutral . slugOp

slugShape :: SegBinOpSlug -> Shape
slugShape = segBinOpShape . slugOp

slugsComm :: [SegBinOpSlug] -> Commutativity
slugsComm = mconcat . map (segBinOpComm . slugOp)

slugSplitParams :: SegBinOpSlug -> ([LParam GPUMem], [LParam GPUMem])
slugSplitParams slug = splitAt (length (slugNeutral slug)) $ slugParams slug

slugBlockRedArrs :: SegBinOpSlug -> [VName]
slugBlockRedArrs = blockRedArrs . slugInterms

slugPrivChunks :: SegBinOpSlug -> [VName]
slugPrivChunks = privateChunks . slugInterms

slugCollCopyArrs :: SegBinOpSlug -> [VName]
slugCollCopyArrs = collCopyArrs . slugInterms
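-- Stage one of the large-segments reduction.  The per-block accumulators are
-- ne-initialised, the reduction lambdas are renamed (and returned, so stage
-- two can reuse them), and 'doBlockReduce' combines each thread's accumulated
-- result across the block.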
reductionStageOne ::
  [VName] ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  [SegBinOpSlug] ->
  DoSegBody ->
  InKernelGen [Lambda GPUMem]
reductionStageOne gtids n global_tid q chunk threads_per_segment slugs body_cont = do
  constants <- kernelConstants <$> askEnv
  let glb_ind_var = mkTV (last gtids)
      ltid = sExt64 $ kernelLocalThreadId constants

  dScope Nothing $ scopeOfLParams $ concatMap slugParams slugs
  sComment "ne-initialise the outer (per-block) accumulator(s)" $ do
    forM_ slugs $ \slug ->
      forM2_ (slugAccs slug) (slugNeutral slug) $ \(acc, acc_is) ne ->
        sLoopNest (slugShape slug) $ \vec_is ->
          copyDWIMFix acc (acc_is ++ vec_is) ne []

  new_lambdas <- mapM (renameLambda . slugLambda) slugs
  let tblock_size = sExt32 $ kernelBlockSize constants
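  -- 'doBlockReduce': each thread first publishes its accumulator (prim-typed
  -- values go into the shared-memory arrays, non-prims are passed through the
  -- lambda parameters in global memory), the block then runs 'blockReduce',
  -- and finally thread 0 folds the block result into the per-block
  -- accumulator while all other threads reset theirs to the neutral element.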
  let doBlockReduce =
        forM2_ slugs new_lambdas $ \slug new_lambda -> do
          let accs = slugAccs slug
          let params = slugParams slug
          sLoopNest (slugShape slug) $ \vec_is -> do
            let block_red_arrs = slugBlockRedArrs slug
            sComment "store accs. prims go in lmem; non-prims in params (in global mem)" $
              forM_ (zip3 block_red_arrs accs params) $
                \(arr, (acc, acc_is), p) ->
                  if isPrimParam p
                    then copyDWIMFix arr [ltid] (Var acc) (acc_is ++ vec_is)
                    else copyDWIMFix (paramName p) [] (Var acc) (acc_is ++ vec_is)
            sOp $ Imp.ErrorSync Imp.FenceLocal
            blockReduce tblock_size new_lambda block_red_arrs
            sOp $ Imp.Barrier Imp.FenceLocal
            sComment "thread 0 updates per-block acc(s); rest reset to ne" $ do
              sIf
                (ltid .==. 0)
                ( forM2_ accs (lambdaParams new_lambda) $
                    \(acc, acc_is) p ->
                      copyDWIMFix acc (acc_is ++ vec_is) (Var $ paramName p) []
                )
                ( forM2_ accs (slugNeutral slug) $
                    \(acc, acc_is)
ne ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
acc ([TPrimExp Int64 VName]
acc_is [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
vec_is) SubExp
ne []
)
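  -- Dispatch between the two stage-one bodies defined below. Roughly: the
  -- specialized noncommutative path applies only when every operand is a
  -- primitive scalar, because it stages per-thread chunks in private and
  -- shared memory; every other combination takes the general path.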
  case (slugsComm slugs, all (isPrimSegBinOp . slugOp) slugs) of
    (Noncommutative, True) ->
      noncommPrimParamsStageOneBody
        slugs
        body_cont
        glb_ind_var
        global_tid
        q
        n
        chunk
        doBlockReduce
    _ ->
      generalStageOneBody
        slugs
        body_cont
        glb_ind_var
        global_tid
        q
        n
        threads_per_segment
        doBlockReduce

  pure new_lambdas
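-- The general stage-one body: over q sequential iterations, each thread folds
-- the elements it is assigned into its per-thread accumulator(s), and the
-- block-wide combination is performed by the doBlockReduce action passed in
-- from above. In essence each thread computes something like the following
-- (purely illustrative; op, ne, input and indexOf are not names in this
-- module):
--
-- > foldl op ne [input ! ix | i <- [0 .. q - 1], let ix = indexOf i, ix < n]
--
-- where indexOf is the commutative or noncommutative indexing scheme chosen
-- inside the loop below.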
generalStageOneBody ::
  [SegBinOpSlug] ->
  DoSegBody ->
  TV Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  InKernelGen () ->
  InKernelGen ()
generalStageOneBody slugs body_cont glb_ind_var global_tid q n threads_per_segment doBlockReduce = do
  let is_comm = slugsComm slugs == Commutative

  constants <- kernelConstants <$> askEnv
  let tblock_size = kernelBlockSize constants
      ltid = sExt64 $ kernelLocalThreadId constants

  tblock_id_in_segment <- dPrimVE "tblock_id_in_segment" $ global_tid `quot` tblock_size
  block_base_offset <- dPrimVE "block_base_offset" $ tblock_id_in_segment * q * tblock_size
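  -- A block covers q * tblock_size elements of its segment. If the operator is
  -- commutative, iteration i reads element global_tid + threads_per_segment * i,
  -- which keeps accesses coalesced without preserving element order; otherwise
  -- the block walks tblock_size-wide windows in order, thread ltid reading
  -- block_offset + ltid. Noncommutative operators are block-reduced at the end
  -- of every iteration, commutative ones only once after the loop. Illustrative
  -- numbers only: with tblock_size = 256 and q = 4, a block covers 1024
  -- elements, its noncommutative windows starting at block_base_offset + 0,
  -- 256, 512 and 768.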
  sFor "i" q $ \i -> do
    block_offset <- dPrimVE "block_offset" $ block_base_offset + i * tblock_size
    glb_ind_var
      <-- if is_comm
        then global_tid + threads_per_segment * i
        else block_offset + ltid

    sWhen (tvExp glb_ind_var .<. n) $
      sComment "apply map function(s)" $
        body_cont $ \all_red_res -> do
          let maps_res = chunks (map (length . slugNeutral) slugs) all_red_res

          forM2_ slugs maps_res $ \slug map_res ->
            sLoopNest (slugShape slug) $ \vec_is -> do
              let (acc_params, next_params) = slugSplitParams slug
              sComment "load accumulator(s)" $
                forM2_ acc_params (slugAccs slug) $ \p (acc, acc_is) ->
                  copyDWIMFix (paramName p) [] (Var acc) (acc_is ++ vec_is)
              sComment "load next value(s)" $
                forM2_ next_params map_res $ \p (res, res_is) ->
                  copyDWIMFix (paramName p) [] res (res_is ++ vec_is)
              sComment "apply reduction operator(s)"
                $ compileStms mempty (bodyStms $ slugBody slug)
                $ sComment "store in accumulator(s)"
                $ forM2_
                  (slugAccs slug)
                  (map resSubExp $ bodyResult $ slugBody slug)
                $ \(acc, acc_is) se ->
                  copyDWIMFix acc (acc_is ++ vec_is) se []

    unless is_comm doBlockReduce

  sOp $ Imp.ErrorSync Imp.FenceLocal

  when is_comm doBlockReduce
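-- Stage-one body specialized for noncommutative operators over primitive
-- operands. Per outer iteration, each thread materializes a private chunk of
-- `chunk` consecutive map results (loading through shared memory so the
-- global accesses stay coalesced), folds that chunk left-to-right, and only
-- then takes part in the block-wide reduction. A sketch of the per-thread,
-- per-iteration work (readChunkVia, sharedMem, op and acc are illustrative
-- names, not generated code):
--
-- > priv <- readChunkVia sharedMem   -- coalesced loads, then redistribute
-- > acc' <- pure (foldl op acc priv) -- in-order sequential reduction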
noncommPrimParamsStageOneBody ::
  [SegBinOpSlug] ->
  DoSegBody ->
  TV Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  InKernelGen () ->
  InKernelGen ()
noncommPrimParamsStageOneBody slugs body_cont glb_ind_var global_tid q n chunk doLMemBlockReduce = do
  constants <- kernelConstants <$> askEnv
  let tblock_size = kernelBlockSize constants
      ltid = sExt64 $ kernelLocalThreadId constants

  tblock_id_in_segment <- dPrimVE "block_offset_in_segment" $ global_tid `quot` tblock_size
  block_stride <- dPrimVE "block_stride" $ tblock_size * chunk
  block_base_offset <- dPrimVE "block_base_offset" $ tblock_id_in_segment * q * block_stride
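  -- With per-thread chunks, a block advances by block_stride = tblock_size *
  -- chunk elements per iteration of the outer loop over q, covering
  -- q * block_stride elements of its segment in total. Illustrative numbers:
  -- tblock_size = 256 and chunk = 4 give block_stride = 1024, so q = 8 covers
  -- 8192 elements per block.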
  let chunkLoop = sFor "k" chunk

  sFor "i" q $ \i -> do
    block_offset <- dPrimVE "block_offset" $ block_base_offset + i * block_stride
    chunkLoop $ \k -> do
      loc_ind <- dPrimVE "loc_ind" $ k * tblock_size + ltid
      glb_ind_var <-- block_offset + loc_ind

      sIf
        (tvExp glb_ind_var .<. n)
        ( body_cont $ \all_red_res -> do
            let slugs_res = chunks (map (length . slugNeutral) slugs) all_red_res
            forM2_ slugs slugs_res $ \slug slug_res -> do
              let priv_chunks = slugPrivChunks slug
              sComment "write map result(s) to private chunk(s)" $
                forM2_ priv_chunks slug_res $ \priv_chunk (res, res_is) ->
                  copyDWIMFix priv_chunk [k] res res_is
        )
        ( forM_ slugs $ \slug ->
            forM2_ (slugPrivChunks slug) (slugNeutral slug) $
              \priv_chunk ne ->
                copyDWIMFix priv_chunk [k] ne []
        )

    sOp $ Imp.ErrorSync Imp.FenceLocal
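    -- At this point the k'th element of thread ltid's private chunk is the map
    -- result of element k * tblock_size + ltid of the current window: the
    -- loads above were coalesced, but each thread's chunk is strided. The
    -- copies below route the chunks through shared memory so that each thread
    -- instead ends up holding `chunk` consecutive elements (ltid * chunk + k),
    -- which the in-order sequential reduction further below requires.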
    sComment "effectualize collective copies in shared memory" $ do
      forM_ slugs $ \slug -> do
        let coll_copy_arrs = slugCollCopyArrs slug
        let priv_chunks = slugPrivChunks slug

        forM2_ coll_copy_arrs priv_chunks $ \lmem_arr priv_chunk -> do
          chunkLoop $ \k -> do
            lmem_idx <- dPrimVE "lmem_idx" $ ltid + k * tblock_size
            copyDWIMFix lmem_arr [lmem_idx] (Var priv_chunk) [k]
          sOp $ Imp.Barrier Imp.FenceLocal

          chunkLoop $ \k -> do
            lmem_idx <- dPrimVE "lmem_idx" $ ltid * chunk + k
            copyDWIMFix priv_chunk [k] (Var lmem_arr) [lmem_idx]
          sOp $ Imp.Barrier Imp.FenceLocal
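    -- Each private chunk now holds elements ltid * chunk .. ltid * chunk +
    -- chunk - 1 of the window in order, so folding the chunk left-to-right and
    -- then block-reducing the per-thread accumulators preserves the operand
    -- order required by a noncommutative operator.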
    sComment "per-thread sequential reduction of private chunk(s)" $ do
      chunkLoop $ \k ->
        forM_ slugs $ \slug -> do
          let accs = map fst $ slugAccs slug
          let (acc_ps, next_ps) = slugSplitParams slug
          let ps_accs_chunks = zip4 acc_ps next_ps accs (slugPrivChunks slug)

          sComment "load params for all reductions" $ do
            forM_ ps_accs_chunks $ \(acc_p, next_p, acc, priv_chunk) -> do
              copyDWIM (paramName acc_p) [] (Var acc) []
              copyDWIMFix (paramName next_p) [] (Var priv_chunk) [k]

          sComment "apply reduction operator(s)" $ do
            let binop_ress = map resSubExp $ bodyResult $ slugBody slug
            compileStms mempty (bodyStms $ slugBody slug) $
              forM2_ accs binop_ress $ \acc binop_res ->
                copyDWIM acc [] binop_res []

    doLMemBlockReduce

  sOp $ Imp.ErrorSync Imp.FenceLocal
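-- Stage two: combine the per-block partial results of a segment into its
-- final result, within the same kernel launch. After publishing its partial
-- result, each block atomically bumps a per-segment counter; the block that
-- sees the counter reach blocks_per_segment - 1 knows every other partial
-- result is visible, resets the counter for the next use, and performs the
-- final reduction itself. A rough sketch of the protocol (the names below are
-- illustrative, not the generated code):
--
-- > writeAtomic blockResults tblockId partial
-- > old <- atomicAdd counter 1
-- > amLast <- pure (old == blocksPerSegment - 1)
-- > when amLast $ do
-- >   _ <- atomicAdd counter (negate blocksPerSegment) -- reset for next use
-- >   writeFinal (blockReduce op blockResults)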
reductionStageTwo ::
  [PatElem LetDecMem] ->
  Imp.TExp Int32 ->
  [Imp.TExp Int64] ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  SegBinOpSlug ->
  Lambda GPUMem ->
  VName ->
  VName ->
  Imp.TExp Int64 ->
  InKernelGen ()
reductionStageTwo segred_pes tblock_id segment_gtids first_block_for_segment blocks_per_segment slug new_lambda counters sync_arr counter_idx = do
  constants <- kernelConstants <$> askEnv
  let ltid32 = kernelLocalThreadId constants
      ltid = sExt64 ltid32
      tblock_size = kernelBlockSize constants

  let (acc_params, next_params) = slugSplitParams slug
      nes = slugNeutral slug
      red_arrs = slugBlockRedArrs slug
      block_res_arrs = blockResArrs slug

  old_counter <- dPrim "old_counter"
  (counter_mem, _, counter_offset) <-
    fullyIndexArray counters [counter_idx]

  sComment "first thread in block saves block result to global memory" $
    sWhen (ltid32 .==. 0) $ do
      forM_ (take (length nes) $ zip block_res_arrs (slugAccs slug)) $ \(v, (acc, acc_is)) ->
        writeAtomic v [0, sExt64 tblock_id] (Var acc) acc_is
      sOp
        $ Imp.Atomic DefaultSpace
        $ Imp.AtomicAdd
          Int32
          (tvVar old_counter)
          counter_mem
          counter_offset
        $ untyped (1 :: Imp.TExp Int32)
      sWrite sync_arr [0] $ untyped $ tvExp old_counter .==. sExt32 (blocks_per_segment - 1)

  sOp $ Imp.Barrier Imp.FenceGlobal
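  -- Every thread of this block now reads back the flag written by its thread
  -- 0, so the whole block agrees on whether it was the last to deposit a
  -- partial result for this segment and must therefore produce the final one.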
  is_last_block <- dPrim "is_last_block"
  copyDWIMFix (tvVar is_last_block) [] (Var sync_arr) [0]
  sWhen (tvExp is_last_block) $ do
    sWhen (ltid32 .==. 0) $
      sOp $
        Imp.Atomic DefaultSpace $
          Imp.AtomicAdd Int32 (tvVar old_counter) counter_mem counter_offset $
            untyped $
              sExt32 (negate blocks_per_segment)
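    -- blocks_per_segment is not necessarily bounded by the block size, so a
    -- thread may have to read several per-block results sequentially:
    -- read_per_thread = blocks_per_segment `divUp` tblock_size. Illustrative
    -- numbers: 600 blocks and a block size of 256 give 3 reads per thread.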
    sLoopNest (slugShape slug) $ \vec_is -> do
      unless (null $ slugShape slug) $
        sOp (Imp.Barrier Imp.FenceLocal)

      sComment "read in the per-block-results" $ do
        read_per_thread <-
          dPrimVE "read_per_thread" $
            blocks_per_segment `divUp` sExt64 tblock_size

        forM2_ acc_params nes $ \p ne ->
          copyDWIM (paramName p) [] ne []

        sFor "i" read_per_thread $ \i -> do
          block_res_id <-
            dPrimVE "block_res_id" $
              ltid * read_per_thread + i
          index_of_block_res <-
            dPrimVE "index_of_block_res" $
              first_block_for_segment + block_res_id

          sWhen (block_res_id .<. blocks_per_segment) $ do
            forM2_ next_params block_res_arrs $
              \p block_res_arr ->
                copyDWIMFix
                  (paramName p)
                  []
                  (Var block_res_arr)
                  ([0, index_of_block_res] ++ vec_is)

            compileStms mempty (bodyStms $ slugBody slug) $
              forM2_ acc_params (map resSubExp $ bodyResult $ slugBody slug) $ \p se ->
                copyDWIMFix (paramName p) [] se []

      forM2_ acc_params red_arrs $ \p arr ->
        when (isPrimParam p) $
          copyDWIMFix arr [ltid] (Var $ paramName p) []

      sOp $ Imp.ErrorSync Imp.FenceLocal

      sComment "reduce the per-block results" $ do
        blockReduce (sExt32 tblock_size) new_lambda red_arrs

        sComment "and back to memory with the final result" $
          sWhen (ltid32 .==. 0) $
            forM2_ segred_pes (lambdaParams new_lambda) $ \pe p ->
              copyDWIMFix
                (patElemName pe)
                (segment_gtids ++ vec_is)
                (Var $ paramName p)
                []