{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.GPU.SegHist (compileSegHist) where
import Control.Monad
import Data.List qualified as L
import Data.Map qualified as M
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.SegRed (compileSegRed')
import Futhark.Construct (fullSliceNum)
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Pass.ExplicitAllocations ()
import Futhark.Transform.Substitute
import Futhark.Util (chunks, mapAccumLM, maxinum, splitFromEnd, takeLast)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)
-- | The intermediate (subhistogram) storage backing one destination
-- array of a histogram operator.
data SubhistosInfo = SubhistosInfo
  { -- | The array holding all subhistograms for this destination.
    subhistosArray :: VName,
    -- | Host code that must be run to make the memory of
    -- 'subhistosArray' ready before it is used.
    subhistosAlloc :: CallKernelGen ()
  }
-- | Everything needed to compile one histogram operator.
data SegHistSlug = SegHistSlug
  { -- | The histogram operator itself.
    slugOp :: HistOp GPUMem,
    -- | Number of subhistograms used per segment.
    slugNumSubhistos :: TV Int64,
    -- | Intermediate storage, one entry per destination array of 'slugOp'.
    slugSubhistos :: [SubhistosInfo],
    -- | The strategy used to update histogram buckets atomically.
    slugAtomicUpdate :: AtomicUpdate GPUMem KernelEnv
  }
-- | Bytes needed for one copy of every destination histogram of an
-- operator: the size of each lambda return type, arrayed over
-- @histShape op <> histOpShape op@, summed over all return types.
histSpaceUsage ::
  HistOp GPUMem ->
  Imp.Count Imp.Bytes (Imp.TExp Int64)
histSpaceUsage op =
  sum . map (typeSize . (`arrayOfShape` (histShape op <> histOpShape op))) $
    lambdaReturnType $
      histOp op
-- | Number of elements in one histogram of the operator, i.e. the
-- product of the dimensions of its 'histShape'.
histSize :: HistOp GPUMem -> Imp.TExp Int64
histSize = product . map pe64 . shapeDims . histShape
-- | Number of dimensions of the operator's 'histShape'.
histRank :: HistOp GPUMem -> Int
histRank = shapeRank . histShape
-- | Set up the intermediate subhistogram arrays for one histogram
-- operator.
--
-- For every destination of @op@ this declares a @<dest>_subhistos@
-- array whose shape is the segment dimensions, then @num_subhistos@,
-- then the destination's shape with the segment dimensions stripped.
-- It also builds a deferred host action ('subhistosAlloc') that
-- initialises that array:
--
-- * if @num_subhistos == 1@, the array's memory is simply set to the
--   destination's memory (the array then has exactly as many elements
--   as the destination);
--
-- * otherwise fresh device memory is allocated, filled with the
--   neutral element, and the destination is copied into subhistogram 0
--   of every segment.
--
-- The @num_subhistos@ variable is only declared here, not assigned.
--
-- Returns, in order:
--
-- * the bytes for one histogram ('histSpaceUsage');
-- * that size multiplied by the product of the segment dimensions;
-- * the assembled 'SegHistSlug'.
computeHistoUsage ::
  SegSpace ->
  HistOp GPUMem ->
  CallKernelGen
    ( Imp.Count Imp.Bytes (Imp.TExp Int64),
      Imp.Count Imp.Bytes (Imp.TExp Int64),
      SegHistSlug
    )
computeHistoUsage space op = do
  -- All but the innermost dimension of the space are segment
  -- dimensions.  NOTE(review): 'init' is partial; this assumes the
  -- SegSpace is never empty.
  let segment_dims = init $ unSegSpace space
      num_segments = length segment_dims

  num_subhistos <- dPrim "num_subhistos"

  subhisto_infos <- forM (zip (histDest op) (histNeutral op)) $ \(dest, ne) -> do
    dest_t <- lookupType dest
    dest_mem <- entryArrayLoc <$> lookupArray dest

    subhistos_mem <-
      sDeclareMem (baseString dest ++ "_subhistos_mem") (Space "device")

    let subhistos_shape =
          Shape (map snd segment_dims ++ [tvSize num_subhistos])
            <> stripDims num_segments (arrayShape dest_t)

    subhistos <-
      sArray
        (baseString dest ++ "_subhistos")
        (elemType dest_t)
        subhistos_shape
        subhistos_mem
        $ LMAD.iota 0
        $ map pe64
        $ shapeDims subhistos_shape

    pure $
      SubhistosInfo subhistos $ do
        let unitHistoCase =
              -- With a single subhistogram the array has as many elements
              -- as the destination, so its memory can be reused directly.
              emit $
                Imp.SetMem subhistos_mem (memLocName dest_mem) $
                  Space "device"

            multiHistoCase = do
              let num_elems = product $ map pe64 $ shapeDims subhistos_shape
                  -- 'withElemType' already yields a byte count.
                  subhistos_mem_size =
                    Imp.elements num_elems `Imp.withElemType` elemType dest_t

              sAlloc_ subhistos_mem subhistos_mem_size $ Space "device"
              sReplicate subhistos ne

              -- Copy the destination's current contents into
              -- subhistogram 0 of every segment.
              subhistos_t <- lookupType subhistos
              let slice =
                    fullSliceNum (map pe64 $ arrayDims subhistos_t) $
                      map (unitSlice 0 . pe64 . snd) segment_dims
                        ++ [DimFix 0]
              sUpdate subhistos slice $ Var dest

        sIf (tvExp num_subhistos .==. 1) unitHistoCase multiHistoCase

  let h = histSpaceUsage op
      segmented_h =
        h * product (map (Imp.bytes . pe64) $ init $ segSpaceDims space)

  atomics <- hostAtomics <$> askEnv

  pure
    ( h,
      segmented_h,
      SegHistSlug op num_subhistos subhisto_infos $
        atomicUpdateLocking atomics $
          histOp op
    )
-- | Prepare the global-memory atomic update for one histogram operator,
-- based on its 'slugAtomicUpdate' strategy.
--
-- * 'AtomicPrim' and 'AtomicCAS': the incoming locking (if any) is
--   passed through unchanged.
--
-- * 'AtomicLocking' with an existing locking: that lock array is used.
--
-- * 'AtomicLocking' without one: a fresh zero-filled lock array of
--   @num_locks@ entries is created.  Indices are mapped to a lock by
--   flattening them against @dims@ and reducing modulo @num_locks@.
--   The new locking is returned, presumably so that the caller can
--   reuse it for further operators.
--
-- Returns the (possibly new) locking, together with a kernel-side
-- function.  Given a list of indices, that function performs the update
-- on @dests@ in global memory.
prepareAtomicUpdateGlobal ::
  Maybe Locking ->
  Shape ->
  [VName] ->
  SegHistSlug ->
  CallKernelGen
    ( Maybe Locking,
      [Imp.TExp Int64] -> InKernelGen ()
    )
prepareAtomicUpdateGlobal l segments dests slug =
  case (l, slugAtomicUpdate slug) of
    (_, AtomicPrim f) -> pure (l, f global_space dests)
    (_, AtomicCAS f) -> pure (l, f global_space dests)
    (Just l', AtomicLocking f) -> pure (l, f l' global_space dests)
    (Nothing, AtomicLocking f) -> do
      locks <- genZeroes "hist_locks" num_locks
      let l' =
            Locking locks 0 1 0 $
              pure . (`rem` fromIntegral num_locks) . flattenIndex dims
      pure (Just l', f l' global_space dests)
  where
    global_space = Space "global"

    -- Size of a freshly created lock array.
    num_locks :: Int
    num_locks = 100151

    -- NOTE(review): confirm that this dimension order matches the order
    -- of the indices passed to the update function.  Because the result
    -- is reduced modulo num_locks, a mismatch would presumably only
    -- affect how locks are distributed.
    dims :: [Imp.TExp Int64]
    dims =
      map pe64 $
        shapeDims segments
          ++ shapeDims (histOpShape (slugOp slug))
          ++ [tvSize (slugNumSubhistos slug)]
          ++ shapeDims (histShape (slugOp slug))
-- | Classification of a kernel body by whether it may be executed in
-- multiple passes ('MayBeMultiPass') or must be executed in a single
-- pass ('MustBeSinglePass').
data Passage = MustBeSinglePass | MayBeMultiPass
  deriving (Eq, Ord)
-- | Classify a kernel body.  If alias analysis (started from an empty
-- alias table) finds that the body consumes nothing, it may be run in
-- multiple passes; otherwise it must be run in a single pass.
bodyPassage :: KernelBody GPUMem -> Passage
bodyPassage kbody
  | mempty == consumedInKernelBody (aliasAnalyseKernelBody mempty kbody) =
      MayBeMultiPass
  | otherwise =
      MustBeSinglePass
prepareIntermediateArraysGlobal ::
Passage ->
Shape ->
Imp.TExp Int32 ->
Imp.TExp Int64 ->
[SegHistSlug] ->
CallKernelGen
( Imp.TExp Int32,
[[Imp.TExp Int64] -> InKernelGen ()]
)
prepareIntermediateArraysGlobal :: Passage
-> Shape
-> TExp Int32
-> TExp Int64
-> [SegHistSlug]
-> CallKernelGen (TExp Int32, [[TExp Int64] -> InKernelGen ()])
prepareIntermediateArraysGlobal Passage
passage Shape
segments TExp Int32
hist_T TExp Int64
hist_N [SegHistSlug]
slugs = do
TExp Int64
hist_H <- [Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_H" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> TExp Int64) -> [SegHistSlug] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (HistOp GPUMem -> TExp Int64
histSize (HistOp GPUMem -> TExp Int64)
-> (SegHistSlug -> HistOp GPUMem) -> SegHistSlug -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs
TExp Double
hist_RF <-
[Char] -> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_RF" (TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double))
-> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall a b. (a -> b) -> a -> b
$
[TExp Double] -> TExp Double
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((SegHistSlug -> TExp Double) -> [SegHistSlug] -> [TExp Double]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> TExp Double
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64 (TExp Int64 -> TExp Double)
-> (SegHistSlug -> TExp Int64) -> SegHistSlug -> TExp Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TExp Int64
pe64 (SubExp -> TExp Int64)
-> (SegHistSlug -> SubExp) -> SegHistSlug -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> SubExp
forall rep. HistOp rep -> SubExp
histRaceFactor (HistOp GPUMem -> SubExp)
-> (SegHistSlug -> HistOp GPUMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs)
TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ [SegHistSlug] -> TExp Double
forall i a. Num i => [a] -> i
L.genericLength [SegHistSlug]
slugs
TExp Int32
hist_el_size <- [Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_el_size" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ [TExp Int32] -> TExp Int32
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([TExp Int32] -> TExp Int32) -> [TExp Int32] -> TExp Int32
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> TExp Int32) -> [SegHistSlug] -> [TExp Int32]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> TExp Int32
slugElAvgSize [SegHistSlug]
slugs
TExp Double
hist_C_max <-
[Char] -> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_C_max" (TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double))
-> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall a b. (a -> b) -> a -> b
$
TExp Double -> TExp Double -> TExp Double
forall v.
TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMin64 (TExp Int32 -> TExp Double
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T) (TExp Double -> TExp Double) -> TExp Double -> TExp Double
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Double
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int64
hist_H TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Double
hist_k_ct_min
TExp Int32
hist_M_min <-
[Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_M_min" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32 -> TExp Int32 -> TExp Int32
forall v. TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMax32 TExp Int32
1 (TExp Int32 -> TExp Int32) -> TExp Int32 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
TExp Double -> TExp Int64
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Int64 v
t64 (TExp Double -> TExp Int64) -> TExp Double -> TExp Int64
forall a b. (a -> b) -> a -> b
$
TExp Int32 -> TExp Double
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Double
hist_C_max
TV Int64
hist_L2 <- [Char] -> SizeClass -> ImpM GPUMem HostEnv HostOp (TV Int64)
getSize [Char]
"hist_L2" SizeClass
Imp.SizeCache
let hist_L2_ln_sz :: TExp Double
hist_L2_ln_sz = TExp Double
16 TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* TExp Double
4
TExp Double
hist_RACE_exp <-
[Char] -> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_RACE_exp" (TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double))
-> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall a b. (a -> b) -> a -> b
$
TExp Double -> TExp Double -> TExp Double
forall v.
TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMax64 TExp Double
1 (TExp Double -> TExp Double) -> TExp Double -> TExp Double
forall a b. (a -> b) -> a -> b
$
(TExp Double
hist_k_RF TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* TExp Double
hist_RF)
TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ (TExp Double
hist_L2_ln_sz TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Int32 -> TExp Double
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_el_size)
TV Int32
hist_S <- [Char] -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall {k} (t :: k) rep r op.
MkTV t =>
[Char] -> ImpM rep r op (TV t)
dPrim [Char]
"hist_S"
TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TExp Int64
hist_N TExp Int64 -> TExp Int64 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
hist_H)
(TV Int32
hist_S TV Int32 -> TExp Int32 -> CallKernelGen ()
forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- (TExp Int32
1 :: Imp.TExp Int32))
(CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ TV Int32
hist_S
TV Int32 -> TExp Int32 -> CallKernelGen ()
forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- case Passage
passage of
Passage
MayBeMultiPass ->
TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
(TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_M_min TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_el_size)
TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Double -> TExp Int64
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Int64 v
t64 (TExp Double
hist_F_L2 TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Double
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64 (TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_L2) TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* TExp Double
hist_RACE_exp)
Passage
MustBeSinglePass ->
TExp Int32
1
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Race expansion factor (RACE^exp)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Double -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Double
hist_RACE_exp
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Number of chunks (S)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
hist_S
[[TExp Int64] -> InKernelGen ()]
histograms <-
(Maybe Locking, [[TExp Int64] -> InKernelGen ()])
-> [[TExp Int64] -> InKernelGen ()]
forall a b. (a, b) -> b
snd
((Maybe Locking, [[TExp Int64] -> InKernelGen ()])
-> [[TExp Int64] -> InKernelGen ()])
-> ImpM
GPUMem
HostEnv
HostOp
(Maybe Locking, [[TExp Int64] -> InKernelGen ()])
-> ImpM GPUMem HostEnv HostOp [[TExp Int64] -> InKernelGen ()]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Maybe Locking
-> SegHistSlug
-> CallKernelGen (Maybe Locking, [TExp Int64] -> InKernelGen ()))
-> Maybe Locking
-> [SegHistSlug]
-> ImpM
GPUMem
HostEnv
HostOp
(Maybe Locking, [[TExp Int64] -> InKernelGen ()])
forall (m :: * -> *) (t :: * -> *) acc x y.
(Monad m, Traversable t) =>
(acc -> x -> m (acc, y)) -> acc -> t x -> m (acc, t y)
mapAccumLM
(TExp Int64
-> TExp Int32
-> TExp Int32
-> TExp Double
-> Maybe Locking
-> SegHistSlug
-> CallKernelGen (Maybe Locking, [TExp Int64] -> InKernelGen ())
onOp (TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_L2) TExp Int32
hist_M_min (TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
hist_S) TExp Double
hist_RACE_exp)
Maybe Locking
forall a. Maybe a
Nothing
[SegHistSlug]
slugs
(TExp Int32, [[TExp Int64] -> InKernelGen ()])
-> CallKernelGen (TExp Int32, [[TExp Int64] -> InKernelGen ()])
forall a. a -> ImpM GPUMem HostEnv HostOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
hist_S, [[TExp Int64] -> InKernelGen ()]
histograms)
where
hist_k_ct_min :: TExp Double
hist_k_ct_min = TExp Double
2
hist_k_RF :: TExp Double
hist_k_RF = TExp Double
0.75
hist_F_L2 :: TExp Double
hist_F_L2 = TExp Double
0.4
r64 :: TPrimExp t v -> TPrimExp Double v
r64 = PrimExp v -> TPrimExp Double v
forall v. PrimExp v -> TPrimExp Double v
isF64 (PrimExp v -> TPrimExp Double v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Double v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> FloatType -> ConvOp
SIToFP IntType
Int32 FloatType
Float64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped
t64 :: TPrimExp t v -> TPrimExp Int64 v
t64 = PrimExp v -> TPrimExp Int64 v
forall v. PrimExp v -> TPrimExp Int64 v
isInt64 (PrimExp v -> TPrimExp Int64 v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Int64 v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (FloatType -> IntType -> ConvOp
FPToSI FloatType
Float64 IntType
Int64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped
slugElAvgSize :: SegHistSlug -> TExp Int32
slugElAvgSize slug :: SegHistSlug
slug@(SegHistSlug HistOp GPUMem
op TV Int64
_ [SubhistosInfo]
_ AtomicUpdate GPUMem KernelEnv
do_op) =
case AtomicUpdate GPUMem KernelEnv
do_op of
AtomicLocking {} ->
SegHistSlug -> TExp Int32
slugElSize SegHistSlug
slug TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` (TExp Int32
1 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ [Type] -> TExp Int32
forall i a. Num i => [a] -> i
L.genericLength (Lambda GPUMem -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType (HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp GPUMem
op)))
AtomicUpdate GPUMem KernelEnv
_ ->
SegHistSlug -> TExp Int32
slugElSize SegHistSlug
slug TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` [Type] -> TExp Int32
forall i a. Num i => [a] -> i
L.genericLength (Lambda GPUMem -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType (HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp GPUMem
op))
slugElSize :: SegHistSlug -> TExp Int32
slugElSize (SegHistSlug HistOp GPUMem
op TV Int64
_ [SubhistosInfo]
_ AtomicUpdate GPUMem KernelEnv
do_op) =
TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32)
-> ([Count Bytes (TExp Int64)] -> TExp Int64)
-> [Count Bytes (TExp Int64)]
-> TExp Int32
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Count Bytes (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
unCount (Count Bytes (TExp Int64) -> TExp Int64)
-> ([Count Bytes (TExp Int64)] -> Count Bytes (TExp Int64))
-> [Count Bytes (TExp Int64)]
-> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Count Bytes (TExp Int64)] -> Count Bytes (TExp Int64)
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Count Bytes (TExp Int64)] -> TExp Int32)
-> [Count Bytes (TExp Int64)] -> TExp Int32
forall a b. (a -> b) -> a -> b
$
case AtomicUpdate GPUMem KernelEnv
do_op of
AtomicLocking {} ->
(Type -> Count Bytes (TExp Int64))
-> [Type] -> [Count Bytes (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> Count Bytes (TExp Int64)
typeSize (Type -> Count Bytes (TExp Int64))
-> (Type -> Type) -> Type -> Count Bytes (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Shape -> Type
`arrayOfShape` HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp GPUMem
op)) ([Type] -> [Count Bytes (TExp Int64)])
-> [Type] -> [Count Bytes (TExp Int64)]
forall a b. (a -> b) -> a -> b
$
PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int32 Type -> [Type] -> [Type]
forall a. a -> [a] -> [a]
: Lambda GPUMem -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType (HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp GPUMem
op)
AtomicUpdate GPUMem KernelEnv
_ ->
(Type -> Count Bytes (TExp Int64))
-> [Type] -> [Count Bytes (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> Count Bytes (TExp Int64)
typeSize (Type -> Count Bytes (TExp Int64))
-> (Type -> Type) -> Type -> Count Bytes (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Shape -> Type
`arrayOfShape` HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp GPUMem
op)) ([Type] -> [Count Bytes (TExp Int64)])
-> [Type] -> [Count Bytes (TExp Int64)]
forall a b. (a -> b) -> a -> b
$
Lambda GPUMem -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType (HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp GPUMem
op)
onOp :: TExp Int64
-> TExp Int32
-> TExp Int32
-> TExp Double
-> Maybe Locking
-> SegHistSlug
-> CallKernelGen (Maybe Locking, [TExp Int64] -> InKernelGen ())
onOp TExp Int64
hist_L2 TExp Int32
hist_M_min TExp Int32
hist_S TExp Double
hist_RACE_exp Maybe Locking
l SegHistSlug
slug = do
let SegHistSlug HistOp GPUMem
op TV Int64
num_subhistos [SubhistosInfo]
subhisto_info AtomicUpdate GPUMem KernelEnv
do_op = SegHistSlug
slug
hist_H :: TExp Int64
hist_H = HistOp GPUMem -> TExp Int64
histSize HistOp GPUMem
op
TExp Int64
hist_H_chk <- [Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_H_chk" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
hist_H TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_S
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Chunk size (H_chk)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_H_chk
TExp Double
hist_k_max <-
[Char] -> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_k_max" (TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double))
-> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall a b. (a -> b) -> a -> b
$
TExp Double -> TExp Double -> TExp Double
forall v.
TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMin64
(TExp Double
hist_F_L2 TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* (TExp Int64 -> TExp Double
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int64
hist_L2 TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Int32 -> TExp Double
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64 (SegHistSlug -> TExp Int32
slugElSize SegHistSlug
slug)) TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* TExp Double
hist_RACE_exp)
(TExp Int64 -> TExp Double
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int64
hist_N)
TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Int32 -> TExp Double
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T
TExp Int64
hist_u <- [Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_u" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
case AtomicUpdate GPUMem KernelEnv
do_op of
AtomicPrim {} -> TExp Int64
2
AtomicUpdate GPUMem KernelEnv
_ -> TExp Int64
1
TExp Double
hist_C <-
[Char] -> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_C" (TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double))
-> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall a b. (a -> b) -> a -> b
$
TExp Double -> TExp Double -> TExp Double
forall v.
TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMin64 (TExp Int32 -> TExp Double
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T) (TExp Double -> TExp Double) -> TExp Double -> TExp Double
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Double
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64 (TExp Int64
hist_u TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H_chk) TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Double
hist_k_max
TExp Int32
hist_M <- [Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_M" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
case SegHistSlug -> AtomicUpdate GPUMem KernelEnv
slugAtomicUpdate SegHistSlug
slug of
AtomicPrim {} -> TExp Int32
1
AtomicUpdate GPUMem KernelEnv
_ -> TExp Int32 -> TExp Int32 -> TExp Int32
forall v. TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMax32 TExp Int32
hist_M_min (TExp Int32 -> TExp Int32) -> TExp Int32 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Double -> TExp Int64
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Int64 v
t64 (TExp Double -> TExp Int64) -> TExp Double -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Double
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Double
hist_C
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Elements/thread in L2 cache (k_max)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Double -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Double
hist_k_max
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Multiplication degree (M)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_M
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Cooperation level (C)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Double -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Double
hist_C
TV Int64
num_subhistos TV Int64 -> TExp Int64 -> CallKernelGen ()
forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_M
[VName]
dests <- [(VName, SubhistosInfo)]
-> ((VName, SubhistosInfo) -> ImpM GPUMem HostEnv HostOp VName)
-> ImpM GPUMem HostEnv HostOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([VName] -> [SubhistosInfo] -> [(VName, SubhistosInfo)]
forall a b. [a] -> [b] -> [(a, b)]
zip (HistOp GPUMem -> [VName]
forall rep. HistOp rep -> [VName]
histDest HistOp GPUMem
op) [SubhistosInfo]
subhisto_info) (((VName, SubhistosInfo) -> ImpM GPUMem HostEnv HostOp VName)
-> ImpM GPUMem HostEnv HostOp [VName])
-> ((VName, SubhistosInfo) -> ImpM GPUMem HostEnv HostOp VName)
-> ImpM GPUMem HostEnv HostOp [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubhistosInfo
info) -> do
MemLoc
dest_mem <- ArrayEntry -> MemLoc
entryArrayLoc (ArrayEntry -> MemLoc)
-> ImpM GPUMem HostEnv HostOp ArrayEntry
-> ImpM GPUMem HostEnv HostOp MemLoc
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM GPUMem HostEnv HostOp ArrayEntry
forall rep r op. VName -> ImpM rep r op ArrayEntry
lookupArray VName
dest
VName
sub_mem <-
(MemLoc -> VName)
-> ImpM GPUMem HostEnv HostOp MemLoc
-> ImpM GPUMem HostEnv HostOp VName
forall a b.
(a -> b)
-> ImpM GPUMem HostEnv HostOp a -> ImpM GPUMem HostEnv HostOp b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap MemLoc -> VName
memLocName (ImpM GPUMem HostEnv HostOp MemLoc
-> ImpM GPUMem HostEnv HostOp VName)
-> ImpM GPUMem HostEnv HostOp MemLoc
-> ImpM GPUMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
ArrayEntry -> MemLoc
entryArrayLoc
(ArrayEntry -> MemLoc)
-> ImpM GPUMem HostEnv HostOp ArrayEntry
-> ImpM GPUMem HostEnv HostOp MemLoc
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM GPUMem HostEnv HostOp ArrayEntry
forall rep r op. VName -> ImpM rep r op ArrayEntry
lookupArray (SubhistosInfo -> VName
subhistosArray SubhistosInfo
info)
let unitHistoCase :: CallKernelGen ()
unitHistoCase =
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName -> VName -> Space -> Code HostOp
forall a. VName -> VName -> Space -> Code a
Imp.SetMem VName
sub_mem (MemLoc -> VName
memLocName MemLoc
dest_mem) (Space -> Code HostOp) -> Space -> Code HostOp
forall a b. (a -> b) -> a -> b
$
[Char] -> Space
Space [Char]
"device"
multiHistoCase :: CallKernelGen ()
multiHistoCase = SubhistosInfo -> CallKernelGen ()
subhistosAlloc SubhistosInfo
info
TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf (TExp Int32
hist_M TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
1) CallKernelGen ()
unitHistoCase CallKernelGen ()
multiHistoCase
VName -> ImpM GPUMem HostEnv HostOp VName
forall a. a -> ImpM GPUMem HostEnv HostOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName -> ImpM GPUMem HostEnv HostOp VName)
-> VName -> ImpM GPUMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ SubhistosInfo -> VName
subhistosArray SubhistosInfo
info
(Maybe Locking
l', [TExp Int64] -> InKernelGen ()
do_op') <- Maybe Locking
-> Shape
-> [VName]
-> SegHistSlug
-> CallKernelGen (Maybe Locking, [TExp Int64] -> InKernelGen ())
prepareAtomicUpdateGlobal Maybe Locking
l Shape
segments [VName]
dests SegHistSlug
slug
(Maybe Locking, [TExp Int64] -> InKernelGen ())
-> CallKernelGen (Maybe Locking, [TExp Int64] -> InKernelGen ())
forall a. a -> ImpM GPUMem HostEnv HostOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe Locking
l', [TExp Int64] -> InKernelGen ()
do_op')
histKernelGlobalPass ::
[PatElem LetDecMem] ->
Count NumBlocks SubExp ->
Count BlockSize SubExp ->
SegSpace ->
[SegHistSlug] ->
KernelBody GPUMem ->
[[Imp.TExp Int64] -> InKernelGen ()] ->
Imp.TExp Int32 ->
Imp.TExp Int32 ->
CallKernelGen ()
histKernelGlobalPass :: [PatElem LParamMem]
-> Count NumBlocks SubExp
-> Count BlockSize SubExp
-> SegSpace
-> [SegHistSlug]
-> KernelBody GPUMem
-> [[TExp Int64] -> InKernelGen ()]
-> TExp Int32
-> TExp Int32
-> CallKernelGen ()
histKernelGlobalPass [PatElem LParamMem]
map_pes Count NumBlocks SubExp
num_tblocks Count BlockSize SubExp
tblock_size SegSpace
space [SegHistSlug]
slugs KernelBody GPUMem
kbody [[TExp Int64] -> InKernelGen ()]
histograms TExp Int32
hist_S TExp Int32
chk_i = do
let ([VName]
space_is, [SubExp]
space_sizes) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
space_sizes_64 :: [TExp Int64]
space_sizes_64 = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int64 -> TExp Int64)
-> (SubExp -> TExp Int64) -> SubExp -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TExp Int64
pe64) [SubExp]
space_sizes
total_w_64 :: TExp Int64
total_w_64 = [TExp Int64] -> TExp Int64
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
space_sizes_64
[TExp Int64]
hist_H_chks <- [TExp Int64]
-> (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> ImpM GPUMem HostEnv HostOp [TExp Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ((SegHistSlug -> TExp Int64) -> [SegHistSlug] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (HistOp GPUMem -> TExp Int64
histSize (HistOp GPUMem -> TExp Int64)
-> (SegHistSlug -> HistOp GPUMem) -> SegHistSlug -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs) ((TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> ImpM GPUMem HostEnv HostOp [TExp Int64])
-> (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> ImpM GPUMem HostEnv HostOp [TExp Int64]
forall a b. (a -> b) -> a -> b
$ \TExp Int64
w ->
[Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_H_chk" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
w TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_S
[Char]
-> VName -> KernelAttrs -> InKernelGen () -> CallKernelGen ()
sKernelThread [Char]
"seghist_global" (SegSpace -> VName
segFlat SegSpace
space) (Count NumBlocks SubExp -> Count BlockSize SubExp -> KernelAttrs
defKernelAttrs Count NumBlocks SubExp
num_tblocks Count BlockSize SubExp
tblock_size) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
[TExp Int32]
subhisto_inds <- [SegHistSlug]
-> (SegHistSlug -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> ImpM GPUMem KernelEnv KernelOp [TExp Int32]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [SegHistSlug]
slugs ((SegHistSlug -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> ImpM GPUMem KernelEnv KernelOp [TExp Int32])
-> (SegHistSlug -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> ImpM GPUMem KernelEnv KernelOp [TExp Int32]
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug ->
[Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"subhisto_ind" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int32
kernelGlobalThreadId KernelConstants
constants)
TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` ( KernelConstants -> TExp Int32
kernelNumThreads KernelConstants
constants
TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp (SegHistSlug -> TV Int64
slugNumSubhistos SegHistSlug
slug))
)
let gtid :: TExp Int64
gtid = TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelGlobalThreadId KernelConstants
constants
num_threads :: TExp Int64
num_threads = TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelNumThreads KernelConstants
constants
TExp Int64
-> TExp Int64
-> TExp Int64
-> (TExp Int64 -> InKernelGen ())
-> InKernelGen ()
forall {k} (t :: k).
IntExp t =>
TExp t
-> TExp t -> TExp t -> (TExp t -> InKernelGen ()) -> InKernelGen ()
kernelLoop TExp Int64
gtid TExp Int64
num_threads TExp Int64
total_w_64 ((TExp Int64 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int64 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
offset -> do
[(VName, TExp Int64)] -> TExp Int64 -> InKernelGen ()
forall rep r op.
[(VName, TExp Int64)] -> TExp Int64 -> ImpM rep r op ()
dIndexSpace ([VName] -> [TExp Int64] -> [(VName, TExp Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
space_is [TExp Int64]
space_sizes_64) TExp Int64
offset
let input_in_bounds :: TExp Bool
input_in_bounds = TExp Int64
offset TExp Int64 -> TExp Int64 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
total_w_64
TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
input_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Names -> Stms GPUMem -> InKernelGen () -> InKernelGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let ([KernelResult]
red_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElem LParamMem] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem LParamMem]
map_pes) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody
Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"save map-out results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElem LParamMem, KernelResult)]
-> ((PatElem LParamMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem]
-> [KernelResult] -> [(PatElem LParamMem, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
map_pes [KernelResult]
map_res) (((PatElem LParamMem, KernelResult) -> InKernelGen ())
-> InKernelGen ())
-> ((PatElem LParamMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, KernelResult
res) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
(((VName, SubExp) -> TExp Int64)
-> [(VName, SubExp)] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 (VName -> TExp Int64)
-> ((VName, SubExp) -> VName) -> (VName, SubExp) -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst) ([(VName, SubExp)] -> [TExp Int64])
-> [(VName, SubExp)] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space)
(KernelResult -> SubExp
kernelResultSubExp KernelResult
res)
[]
let red_res_split :: [([SubExp], [SubExp])]
red_res_split =
[HistOp GPUMem] -> [SubExp] -> [([SubExp], [SubExp])]
forall rep. [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) ([SubExp] -> [([SubExp], [SubExp])])
-> [SubExp] -> [([SubExp], [SubExp])]
forall a b. (a -> b) -> a -> b
$
(KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp [KernelResult]
red_res
Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"perform atomic updates" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(HistOp GPUMem, [TExp Int64] -> InKernelGen (),
([SubExp], [SubExp]), TExp Int32, TExp Int64)]
-> ((HistOp GPUMem, [TExp Int64] -> InKernelGen (),
([SubExp], [SubExp]), TExp Int32, TExp Int64)
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp GPUMem]
-> [[TExp Int64] -> InKernelGen ()]
-> [([SubExp], [SubExp])]
-> [TExp Int32]
-> [TExp Int64]
-> [(HistOp GPUMem, [TExp Int64] -> InKernelGen (),
([SubExp], [SubExp]), TExp Int32, TExp Int64)]
forall a b c d e.
[a] -> [b] -> [c] -> [d] -> [e] -> [(a, b, c, d, e)]
L.zip5 ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) [[TExp Int64] -> InKernelGen ()]
histograms [([SubExp], [SubExp])]
red_res_split [TExp Int32]
subhisto_inds [TExp Int64]
hist_H_chks) (((HistOp GPUMem, [TExp Int64] -> InKernelGen (),
([SubExp], [SubExp]), TExp Int32, TExp Int64)
-> InKernelGen ())
-> InKernelGen ())
-> ((HistOp GPUMem, [TExp Int64] -> InKernelGen (),
([SubExp], [SubExp]), TExp Int32, TExp Int64)
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
\( HistOp Shape
dest_shape SubExp
_ [VName]
_ [SubExp]
_ Shape
shape Lambda GPUMem
lam,
[TExp Int64] -> InKernelGen ()
do_op,
([SubExp]
bucket, [SubExp]
vs'),
TExp Int32
subhisto_ind,
TExp Int64
hist_H_chk
) -> do
let chk_beg :: TExp Int64
chk_beg = TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H_chk
bucket' :: [TExp Int64]
bucket' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
bucket
dest_shape' :: [TExp Int64]
dest_shape' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 ([SubExp] -> [TExp Int64]) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
dest_shape
flat_bucket :: TExp Int64
flat_bucket = [TExp Int64] -> [TExp Int64] -> TExp Int64
forall num. IntegralExp num => [num] -> [num] -> num
flattenIndex [TExp Int64]
dest_shape' [TExp Int64]
bucket'
bucket_in_bounds :: TExp Bool
bucket_in_bounds =
TExp Int64
chk_beg
TExp Int64 -> TExp Int64 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
flat_bucket
TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
flat_bucket
TExp Int64 -> TExp Int64 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. (TExp Int64
chk_beg TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
hist_H_chk)
TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. Slice (TExp Int64) -> [TExp Int64] -> TExp Bool
inBounds ([DimIndex (TExp Int64)] -> Slice (TExp Int64)
forall d. [DimIndex d] -> Slice d
Slice ((TExp Int64 -> DimIndex (TExp Int64))
-> [TExp Int64] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix [TExp Int64]
bucket')) [TExp Int64]
dest_shape'
vs_params :: [Param LParamMem]
vs_params = Int -> [Param LParamMem] -> [Param LParamMem]
forall a. Int -> [a] -> [a]
takeLast ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') ([Param LParamMem] -> [Param LParamMem])
-> [Param LParamMem] -> [Param LParamMem]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
bucket_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let bucket_is :: [TExp Int64]
bucket_is =
(VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 ([VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
init [VName]
space_is)
[TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
subhisto_ind]
[TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
dest_shape' TExp Int64
flat_bucket
[LParam GPUMem] -> InKernelGen ()
forall rep (inner :: * -> *) r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam GPUMem] -> InKernelGen ())
-> [LParam GPUMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
shape (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is -> do
[(Param LParamMem, SubExp)]
-> ((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [SubExp] -> [(Param LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
vs_params [SubExp]
vs') (((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
res) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
res [TExp Int64]
is
[TExp Int64] -> InKernelGen ()
do_op ([TExp Int64]
bucket_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
is)
-- | Compile a 'SegHist' whose intermediate histograms live entirely in
-- global memory.
--
-- The innermost dimension of @space@ is the per-segment input; all outer
-- dimensions are segments.  The intermediate arrays are prepared once,
-- after which 'histKernelGlobalPass' is launched once for every chunk
-- index @chk_i@ in the @hist_S@ chunks chosen by
-- 'prepareIntermediateArraysGlobal'.
histKernelGlobal ::
  [PatElem LetDecMem] ->
  Count NumBlocks SubExp ->
  Count BlockSize SubExp ->
  SegSpace ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  CallKernelGen ()
histKernelGlobal map_pes num_tblocks tblock_size space slugs kbody = do
  emit $ Imp.DebugPrint "## Using global memory" Nothing

  (hist_S, histograms) <-
    prepareIntermediateArraysGlobal
      (bodyPassage kbody)
      (Shape segment_dims)
      total_threads
      (pe64 inner_dim)
      slugs

  -- One kernel pass per histogram chunk; the chunk index is the last
  -- argument of the pass.
  sFor "chk_i" hist_S $
    histKernelGlobalPass map_pes num_tblocks tblock_size space slugs kbody histograms hist_S
  where
    space_dims :: [SubExp]
    space_dims = map snd $ unSegSpace space

    -- Every dimension but the innermost identifies a segment.
    segment_dims :: [SubExp]
    segment_dims = init space_dims

    inner_dim :: SubExp
    inner_dim = last space_dims

    -- Total number of launched threads, narrowed to 32 bits.
    total_threads :: Imp.TExp Int32
    total_threads =
      sExt32 $ pe64 (unCount num_tblocks) * pe64 (unCount tblock_size)
-- | Per-histogram-operation setup for the shared-memory strategy.
--
-- Each entry pairs the names of the global-memory subhistogram arrays
-- with a kernel-side initialiser.  Given the histogram chunk size
-- (@hist_H_chk@), the initialiser allocates the shared-memory
-- subhistograms.  It returns their names together with the function
-- used to apply an update at a given index.
type InitLocalHistograms =
  [([VName], SubExp -> InKernelGen ([VName], [Imp.TExp Int64] -> InKernelGen ()))]
-- | Host-side preparation for the shared-memory histogram strategy.
--
-- For every slug this function:
--
--   * sets the slug's number of global subhistograms to
--     @blocks_per_segment@ and debug-prints it;
--   * allocates the global-memory subhistograms;
--   * returns an initialiser to be run inside the kernel.  Given the
--     chunk size @hist_H_chk@, the initialiser allocates one
--     shared-memory subhistogram per histogram type, shaped
--     @[num_subhistos_per_block] ++ (type shape with its histRank outer
--     dims replaced by hist_H_chk)@.  It also builds the atomic update
--     function for those arrays.  Lock-based updates additionally get a
--     shared @int32@ lock array that is zeroed before use.
prepareIntermediateArraysLocal ::
  TV Int32 ->
  Count NumBlocks (Imp.TExp Int64) ->
  [SegHistSlug] ->
  CallKernelGen InitLocalHistograms
prepareIntermediateArraysLocal num_subhistos_per_block blocks_per_segment =
  mapM prepareSlug
  where
    -- Outermost dimension shared by every local subhistogram and lock array.
    per_block_dim = tvSize num_subhistos_per_block

    prepareSlug (SegHistSlug op num_subhistos subhisto_info do_op) = do
      num_subhistos <-- sExt64 (unCount blocks_per_segment)
      emit . Imp.DebugPrint "Number of subhistograms in global memory per segment" $
        Just (untyped (tvExp num_subhistos))

      glob_subhistos <- forM subhisto_info $ \info ->
        subhistosAlloc info >> pure (subhistosArray info)

      pure (glob_subhistos, initLocal op do_op)

    -- Kernel-side initialiser: local subhistograms first, then whatever
    -- the update mechanism needs (e.g. locks).
    initLocal op do_op hist_H_chk = do
      local_subhistos <- mapM (allocLocalSubhisto op hist_H_chk) (histType op)
      apply_update <- atomicUpdateFor do_op hist_H_chk
      pure (local_subhistos, apply_update (Space "shared") local_subhistos)

    -- One shared-memory subhistogram array for a single histogram type.
    allocLocalSubhisto op hist_H_chk t =
      sAllocArray "subhistogram_local" (elemType t) local_shape (Space "shared")
      where
        local_shape =
          Shape [per_block_dim]
            <> setOuterDims (arrayShape t) (histRank op) (Shape [hist_H_chk])

    -- Primitive and CAS updates need no extra state.
    atomicUpdateFor (AtomicPrim f) _ = pure f
    atomicUpdateFor (AtomicCAS f) _ = pure f
    -- Lock-based updates need a shared lock per subhistogram bucket.
    atomicUpdateFor (AtomicLocking f) hist_H_chk = do
      let lock_shape = Shape [per_block_dim, hist_H_chk]
      locks <- sAllocArray "locks" int32 lock_shape $ Space "shared"
      sComment "All locks start out unlocked" $
        blockCoverSpace (map pe64 (shapeDims lock_shape)) $ \is ->
          copyDWIMFix locks is (intConst Int32 0) []
      -- 0 marks a lock as free, matching the initialisation above.
      pure . f $ Locking locks 0 1 0 id
histKernelLocalPass ::
TV Int32 ->
Count NumBlocks (Imp.TExp Int64) ->
[PatElem LetDecMem] ->
Count NumBlocks SubExp ->
Count BlockSize SubExp ->
SegSpace ->
[SegHistSlug] ->
KernelBody GPUMem ->
InitLocalHistograms ->
Imp.TExp Int32 ->
Imp.TExp Int32 ->
CallKernelGen ()
histKernelLocalPass :: TV Int32
-> Count NumBlocks (TExp Int64)
-> [PatElem LParamMem]
-> Count NumBlocks SubExp
-> Count BlockSize SubExp
-> SegSpace
-> [SegHistSlug]
-> KernelBody GPUMem
-> InitLocalHistograms
-> TExp Int32
-> TExp Int32
-> CallKernelGen ()
histKernelLocalPass
TV Int32
num_subhistos_per_block_var
Count NumBlocks (TExp Int64)
blocks_per_segment
[PatElem LParamMem]
map_pes
Count NumBlocks SubExp
num_tblocks
Count BlockSize SubExp
tblock_size
SegSpace
space
[SegHistSlug]
slugs
KernelBody GPUMem
kbody
InitLocalHistograms
init_histograms
TExp Int32
hist_S
TExp Int32
chk_i = do
let ([VName]
space_is, [SubExp]
space_sizes) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
segment_is :: [VName]
segment_is = [VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
init [VName]
space_is
segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
init [SubExp]
space_sizes
(VName
i_in_segment, SubExp
segment_size) = [(VName, SubExp)] -> (VName, SubExp)
forall a. HasCallStack => [a] -> a
last ([(VName, SubExp)] -> (VName, SubExp))
-> [(VName, SubExp)] -> (VName, SubExp)
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
num_subhistos_per_block :: TExp Int32
num_subhistos_per_block = TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
num_subhistos_per_block_var
segment_size' :: TExp Int64
segment_size' = SubExp -> TExp Int64
pe64 SubExp
segment_size
TExp Int64
num_segments <- [Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"num_segments" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
segment_dims
[TV Int64]
hist_H_chks <- [HistOp GPUMem]
-> (HistOp GPUMem -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> ImpM GPUMem HostEnv HostOp [TV Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) ((HistOp GPUMem -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> ImpM GPUMem HostEnv HostOp [TV Int64])
-> (HistOp GPUMem -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> ImpM GPUMem HostEnv HostOp [TV Int64]
forall a b. (a -> b) -> a -> b
$ \HistOp GPUMem
op ->
[Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"hist_H_chk" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> TExp Int64
histSize HistOp GPUMem
op TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_S
[([TExp Int64], TExp Int64, TExp Int32)]
histo_sizes <- [(SegHistSlug, TV Int64)]
-> ((SegHistSlug, TV Int64)
-> ImpM
GPUMem HostEnv HostOp ([TExp Int64], TExp Int64, TExp Int32))
-> ImpM
GPUMem HostEnv HostOp [([TExp Int64], TExp Int64, TExp Int32)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([SegHistSlug] -> [TV Int64] -> [(SegHistSlug, TV Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegHistSlug]
slugs [TV Int64]
hist_H_chks) (((SegHistSlug, TV Int64)
-> ImpM
GPUMem HostEnv HostOp ([TExp Int64], TExp Int64, TExp Int32))
-> ImpM
GPUMem HostEnv HostOp [([TExp Int64], TExp Int64, TExp Int32)])
-> ((SegHistSlug, TV Int64)
-> ImpM
GPUMem HostEnv HostOp ([TExp Int64], TExp Int64, TExp Int32))
-> ImpM
GPUMem HostEnv HostOp [([TExp Int64], TExp Int64, TExp Int32)]
forall a b. (a -> b) -> a -> b
$ \(SegHistSlug
slug, TV Int64
hist_H_chk) -> do
let histo_dims :: [TExp Int64]
histo_dims =
TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_H_chk TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)))
TExp Int64
histo_size <-
[Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"histo_size" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
histo_dims
let block_hists_size :: TExp Int64
block_hists_size = TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
num_subhistos_per_block TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
histo_size
TExp Int32
init_per_thread <-
[Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"init_per_thread" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Int64
block_hists_size TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` SubExp -> TExp Int64
pe64 (Count BlockSize SubExp -> SubExp
forall {k} (u :: k) e. Count u e -> e
unCount Count BlockSize SubExp
tblock_size)
([TExp Int64], TExp Int64, TExp Int32)
-> ImpM
GPUMem HostEnv HostOp ([TExp Int64], TExp Int64, TExp Int32)
forall a. a -> ImpM GPUMem HostEnv HostOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([TExp Int64]
histo_dims, TExp Int64
histo_size, TExp Int32
init_per_thread)
let attrs :: KernelAttrs
attrs = (Count NumBlocks SubExp -> Count BlockSize SubExp -> KernelAttrs
defKernelAttrs Count NumBlocks SubExp
num_tblocks Count BlockSize SubExp
tblock_size) {kAttrCheckSharedMemory = False}
[Char]
-> VName -> KernelAttrs -> InKernelGen () -> CallKernelGen ()
sKernelThread [Char]
"seghist_local" (SegSpace -> VName
segFlat SegSpace
space) KernelAttrs
attrs (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
SegVirt
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
virtualiseBlocks SegVirt
SegVirt (TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ Count NumBlocks (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
unCount Count NumBlocks (TExp Int64)
blocks_per_segment TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
num_segments) ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
tblock_id -> do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
TExp Int32
flat_segment_id <- [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"flat_segment_id" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
tblock_id TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (Count NumBlocks (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
unCount Count NumBlocks (TExp Int64)
blocks_per_segment)
TExp Int32
gid_in_segment <- [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"gid_in_segment" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
tblock_id TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (Count NumBlocks (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
unCount Count NumBlocks (TExp Int64)
blocks_per_segment)
TExp Int32
pgtid_in_segment <-
[Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"pgtid_in_segment" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32
gid_in_segment TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelBlockSize KernelConstants
constants)
TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
TExp Int32
threads_per_segment <-
[Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"threads_per_segment" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
Count NumBlocks (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
unCount Count NumBlocks (TExp Int64)
blocks_per_segment TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* KernelConstants -> TExp Int64
kernelBlockSize KernelConstants
constants
(VName -> TExp Int64 -> InKernelGen ())
-> [VName] -> [TExp Int64] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> InKernelGen ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
segment_is ([TExp Int64] -> InKernelGen ()) -> [TExp Int64] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ((SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
segment_dims) (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$
TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
flat_segment_id
[([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
histograms <- [(([VName],
SubExp -> InKernelGen ([VName], [TExp Int64] -> InKernelGen ())),
TV Int64)]
-> ((([VName],
SubExp -> InKernelGen ([VName], [TExp Int64] -> InKernelGen ())),
TV Int64)
-> ImpM
GPUMem
KernelEnv
KernelOp
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()))
-> ImpM
GPUMem
KernelEnv
KernelOp
[([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (InitLocalHistograms
-> [TV Int64]
-> [(([VName],
SubExp -> InKernelGen ([VName], [TExp Int64] -> InKernelGen ())),
TV Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip InitLocalHistograms
init_histograms [TV Int64]
hist_H_chks) (((([VName],
SubExp -> InKernelGen ([VName], [TExp Int64] -> InKernelGen ())),
TV Int64)
-> ImpM
GPUMem
KernelEnv
KernelOp
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()))
-> ImpM
GPUMem
KernelEnv
KernelOp
[([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())])
-> ((([VName],
SubExp -> InKernelGen ([VName], [TExp Int64] -> InKernelGen ())),
TV Int64)
-> ImpM
GPUMem
KernelEnv
KernelOp
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()))
-> ImpM
GPUMem
KernelEnv
KernelOp
[([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
forall a b. (a -> b) -> a -> b
$
\(([VName]
glob_subhistos, SubExp -> InKernelGen ([VName], [TExp Int64] -> InKernelGen ())
init_local_subhistos), TV Int64
hist_H_chk) -> do
([VName]
local_subhistos, [TExp Int64] -> InKernelGen ()
do_op) <- SubExp -> InKernelGen ([VName], [TExp Int64] -> InKernelGen ())
init_local_subhistos (SubExp -> InKernelGen ([VName], [TExp Int64] -> InKernelGen ()))
-> SubExp -> InKernelGen ([VName], [TExp Int64] -> InKernelGen ())
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV Int64
hist_H_chk
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())
-> ImpM
GPUMem
KernelEnv
KernelOp
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())
forall a. a -> ImpM GPUMem KernelEnv KernelOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
glob_subhistos [VName]
local_subhistos, TV Int64
hist_H_chk, [TExp Int64] -> InKernelGen ()
do_op)
TExp Int32
thread_local_subhisto_i <-
[Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"thread_local_subhisto_i" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int32
num_subhistos_per_block
let onSlugs :: (SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
onSlugs SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ()
f =
[(SegHistSlug,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
([TExp Int64], TExp Int64, TExp Int32))]
-> ((SegHistSlug,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
([TExp Int64], TExp Int64, TExp Int32))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegHistSlug]
-> [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
-> [([TExp Int64], TExp Int64, TExp Int32)]
-> [(SegHistSlug,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
([TExp Int64], TExp Int64, TExp Int32))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegHistSlug]
slugs [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
histograms [([TExp Int64], TExp Int64, TExp Int32)]
histo_sizes) (((SegHistSlug,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
([TExp Int64], TExp Int64, TExp Int32))
-> InKernelGen ())
-> InKernelGen ())
-> ((SegHistSlug,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
([TExp Int64], TExp Int64, TExp Int32))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
\(SegHistSlug
slug, ([(VName, VName)]
dests, TV Int64
hist_H_chk, [TExp Int64] -> InKernelGen ()
_), ([TExp Int64]
histo_dims, TExp Int64
histo_size, TExp Int32
init_per_thread)) ->
SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ()
f SegHistSlug
slug [(VName, VName)]
dests (TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_H_chk) [TExp Int64]
histo_dims TExp Int64
histo_size TExp Int32
init_per_thread
let onAllHistograms :: (VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ())
-> InKernelGen ()
onAllHistograms VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ()
f =
(SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
onSlugs ((SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ())
-> (SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests TExp Int64
hist_H_chk [TExp Int64]
histo_dims TExp Int64
histo_size TExp Int32
init_per_thread -> do
let block_hists_size :: TExp Int32
block_hists_size = TExp Int32
num_subhistos_per_block TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size
[((VName, VName), SubExp)]
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, VName)] -> [SubExp] -> [((VName, VName), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(VName, VName)]
dests (HistOp GPUMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral (HistOp GPUMem -> [SubExp]) -> HistOp GPUMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)) ((((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ())
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
\((VName
dest_global, VName
dest_local), SubExp
ne) ->
[Char]
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall {k} (t :: k) rep r op.
[Char]
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor [Char]
"local_i" TExp Int32
init_per_thread ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
TExp Int32
j <-
[Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"j" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32
i TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelBlockSize KernelConstants
constants)
TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
TExp Int32
j_offset <-
[Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"j_offset" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32
num_subhistos_per_block TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
gid_in_segment TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
j
TExp Int32
local_subhisto_i <- [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"local_subhisto_i" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size
let local_bucket_is :: [TExp Int64]
local_bucket_is = [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
histo_dims (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size
nested_hist_size :: [TExp Int64]
nested_hist_size =
(SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 ([SubExp] -> [TExp Int64]) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape (HistOp GPUMem -> Shape) -> HistOp GPUMem -> Shape
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
global_bucket_is :: [TExp Int64]
global_bucket_is =
[TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
[TExp Int64]
nested_hist_size
([TExp Int64] -> TExp Int64
forall a. HasCallStack => [a] -> a
head [TExp Int64]
local_bucket_is TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H_chk)
[TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64] -> [TExp Int64]
forall a. HasCallStack => [a] -> [a]
tail [TExp Int64]
local_bucket_is
TExp Int32
global_subhisto_i <- [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"global_subhisto_i" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
j_offset TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size
TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
block_hists_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ()
f
VName
dest_local
VName
dest_global
(SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)
SubExp
ne
TExp Int32
local_subhisto_i
TExp Int32
global_subhisto_i
[TExp Int64]
local_bucket_is
[TExp Int64]
global_bucket_is
Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"initialize histograms in shared memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ())
-> InKernelGen ()
onAllHistograms ((VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ())
-> InKernelGen ())
-> (VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
dest_local VName
dest_global HistOp GPUMem
op SubExp
ne TExp Int32
local_subhisto_i TExp Int32
global_subhisto_i [TExp Int64]
local_bucket_is [TExp Int64]
global_bucket_is ->
Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"First subhistogram is initialised from global memory; others with neutral element." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
[TExp Int64]
dest_global_shape <- (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 ([SubExp] -> [TExp Int64])
-> (Type -> [SubExp]) -> Type -> [TExp Int64]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (Type -> [TExp Int64])
-> ImpM GPUMem KernelEnv KernelOp Type
-> ImpM GPUMem KernelEnv KernelOp [TExp Int64]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM GPUMem KernelEnv KernelOp Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
dest_global
let global_is :: [TExp Int64]
global_is = (VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
segment_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64
0] [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
global_bucket_is
local_is :: [TExp Int64]
local_is = TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
local_subhisto_i TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: [TExp Int64]
local_bucket_is
global_in_bounds :: TExp Bool
global_in_bounds =
Slice (TExp Int64) -> [TExp Int64] -> TExp Bool
inBounds ([DimIndex (TExp Int64)] -> Slice (TExp Int64)
forall d. [DimIndex d] -> Slice d
Slice ((TExp Int64 -> DimIndex (TExp Int64))
-> [TExp Int64] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix [TExp Int64]
global_is)) [TExp Int64]
dest_global_shape
TExp Bool -> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TExp Int32
global_subhisto_i TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0 TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool
global_in_bounds)
(VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest_local [TExp Int64]
local_is (VName -> SubExp
Var VName
dest_global) [TExp Int64]
global_is)
( Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp GPUMem
op) (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest_local ([TExp Int64]
local_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
is) SubExp
ne []
)
KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
TExp Int64
-> TExp Int64
-> TExp Int64
-> (TExp Int64 -> InKernelGen ())
-> InKernelGen ()
forall {k} (t :: k).
IntExp t =>
TExp t
-> TExp t -> TExp t -> (TExp t -> InKernelGen ()) -> InKernelGen ()
kernelLoop (TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
pgtid_in_segment) (TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
threads_per_segment) TExp Int64
segment_size' ((TExp Int64 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int64 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
ie -> do
VName -> TExp Int64 -> InKernelGen ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
i_in_segment TExp Int64
ie
Names -> Stms GPUMem -> InKernelGen () -> InKernelGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let ([SubExp]
red_res, [SubExp]
map_res) =
Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElem LParamMem] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem LParamMem]
map_pes) ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$
(KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp ([KernelResult] -> [SubExp]) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> a -> b
$
KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody
TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
chk_i TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"save map-out results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElem LParamMem, SubExp)]
-> ((PatElem LParamMem, SubExp) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem] -> [SubExp] -> [(PatElem LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
map_pes [SubExp]
map_res) (((PatElem LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((PatElem LParamMem, SubExp) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, SubExp
se) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
space_is)
SubExp
se
[]
let red_res_split :: [([SubExp], [SubExp])]
red_res_split = [HistOp GPUMem] -> [SubExp] -> [([SubExp], [SubExp])]
forall rep. [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) [SubExp]
red_res
[(HistOp GPUMem,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
([SubExp], [SubExp]))]
-> ((HistOp GPUMem,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
([SubExp], [SubExp]))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp GPUMem]
-> [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
-> [([SubExp], [SubExp])]
-> [(HistOp GPUMem,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
([SubExp], [SubExp]))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
histograms [([SubExp], [SubExp])]
red_res_split) (((HistOp GPUMem,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
([SubExp], [SubExp]))
-> InKernelGen ())
-> InKernelGen ())
-> ((HistOp GPUMem,
([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
([SubExp], [SubExp]))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
\( HistOp Shape
dest_shape SubExp
_ [VName]
_ [SubExp]
_ Shape
shape Lambda GPUMem
lam,
([(VName, VName)]
_, TV Int64
hist_H_chk, [TExp Int64] -> InKernelGen ()
do_op),
([SubExp]
bucket, [SubExp]
vs')
) -> do
let chk_beg :: TExp Int64
chk_beg = TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_H_chk
bucket' :: [TExp Int64]
bucket' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
bucket
dest_shape' :: [TExp Int64]
dest_shape' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 ([SubExp] -> [TExp Int64]) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
dest_shape
flat_bucket :: TExp Int64
flat_bucket = [TExp Int64] -> [TExp Int64] -> TExp Int64
forall num. IntegralExp num => [num] -> [num] -> num
flattenIndex [TExp Int64]
dest_shape' [TExp Int64]
bucket'
bucket_in_bounds :: TExp Bool
bucket_in_bounds =
Slice (TExp Int64) -> [TExp Int64] -> TExp Bool
inBounds ([DimIndex (TExp Int64)] -> Slice (TExp Int64)
forall d. [DimIndex d] -> Slice d
Slice ((TExp Int64 -> DimIndex (TExp Int64))
-> [TExp Int64] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix [TExp Int64]
bucket')) [TExp Int64]
dest_shape'
TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
chk_beg
TExp Int64 -> TExp Int64 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
flat_bucket
TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
flat_bucket
TExp Int64 -> TExp Int64 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. (TExp Int64
chk_beg TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_H_chk)
bucket_is :: [TExp Int64]
bucket_is =
[TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
thread_local_subhisto_i, TExp Int64
flat_bucket TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
chk_beg]
vs_params :: [Param LParamMem]
vs_params = Int -> [Param LParamMem] -> [Param LParamMem]
forall a. Int -> [a] -> [a]
takeLast ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') ([Param LParamMem] -> [Param LParamMem])
-> [Param LParamMem] -> [Param LParamMem]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"perform atomic updates" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
bucket_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
[LParam GPUMem] -> InKernelGen ()
forall rep (inner :: * -> *) r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam GPUMem] -> InKernelGen ())
-> [LParam GPUMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
shape (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is -> do
[(Param LParamMem, SubExp)]
-> ((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [SubExp] -> [(Param LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
vs_params [SubExp]
vs') (((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
v) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
v [TExp Int64]
is
[TExp Int64] -> InKernelGen ()
do_op ([TExp Int64]
bucket_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
is)
KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceGlobal
Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Compact the multiple shared memory subhistograms to result in global memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
onSlugs ((SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ())
-> (SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests TExp Int64
hist_H_chk [TExp Int64]
histo_dims TExp Int64
_histo_size TExp Int32
bins_per_thread -> do
TV Int64
trunc_H <-
[Char] -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"trunc_H" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> (TExp Int64 -> TExp Int64)
-> TExp Int64
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 TExp Int64
hist_H_chk (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
HistOp GPUMem -> TExp Int64
histSize (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* [TExp Int64] -> TExp Int64
forall a. HasCallStack => [a] -> a
head [TExp Int64]
histo_dims
let trunc_histo_dims :: [TExp Int64]
trunc_histo_dims =
TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
trunc_H
TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)))
TExp Int32
trunc_histo_size <- [Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"histo_size" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
trunc_histo_dims
[Char]
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall {k} (t :: k) rep r op.
[Char]
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor [Char]
"local_i" TExp Int32
bins_per_thread ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
TExp Int32
j <-
[Char] -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"j" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int32
i TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelBlockSize KernelConstants
constants)
TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
trunc_histo_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let local_bucket_is :: [TExp Int64]
local_bucket_is = [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
histo_dims (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
j
nested_hist_size :: [TExp Int64]
nested_hist_size =
(SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 ([SubExp] -> [TExp Int64]) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape (HistOp GPUMem -> Shape) -> HistOp GPUMem -> Shape
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
global_bucket_is :: [TExp Int64]
global_bucket_is =
[TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
[TExp Int64]
nested_hist_size
([TExp Int64] -> TExp Int64
forall a. HasCallStack => [a] -> a
head [TExp Int64]
local_bucket_is TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H_chk)
[TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64] -> [TExp Int64]
forall a. HasCallStack => [a] -> [a]
tail [TExp Int64]
local_bucket_is
[LParam GPUMem] -> InKernelGen ()
forall rep (inner :: * -> *) r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam GPUMem] -> InKernelGen ())
-> [LParam GPUMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams (Lambda GPUMem -> [LParam GPUMem])
-> Lambda GPUMem -> [LParam GPUMem]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp (HistOp GPUMem -> Lambda GPUMem) -> HistOp GPUMem -> Lambda GPUMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
let ([VName]
global_dests, [VName]
local_dests) = [(VName, VName)] -> ([VName], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip [(VName, VName)]
dests
([Param LParamMem]
xparams, [Param LParamMem]
yparams) =
Int -> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
local_dests) ([Param LParamMem] -> ([Param LParamMem], [Param LParamMem]))
-> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a b. (a -> b) -> a -> b
$
Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams (Lambda GPUMem -> [LParam GPUMem])
-> Lambda GPUMem -> [LParam GPUMem]
forall a b. (a -> b) -> a -> b
$
HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp (HistOp GPUMem -> Lambda GPUMem) -> HistOp GPUMem -> Lambda GPUMem
forall a b. (a -> b) -> a -> b
$
SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Read values from subhistogram 0." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
xparams [VName]
local_dests) (((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
xp, VName
subhisto) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
xp)
[]
(VName -> SubExp
Var VName
subhisto)
(TExp Int64
0 TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: [TExp Int64]
local_bucket_is)
Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Accumulate based on values in other subhistograms." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[Char]
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall {k} (t :: k) rep r op.
[Char]
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor [Char]
"subhisto_id" (TExp Int32
num_subhistos_per_block TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1) ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
subhisto_id -> do
[(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
yparams [VName]
local_dests) (((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
yp, VName
subhisto) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
yp)
[]
(VName -> SubExp
Var VName
subhisto)
(TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
subhisto_id TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1 TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: [TExp Int64]
local_bucket_is)
[Param LParamMem] -> Body GPUMem -> InKernelGen ()
forall dec rep r op. [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' [Param LParamMem]
xparams (Body GPUMem -> InKernelGen ()) -> Body GPUMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody (Lambda GPUMem -> Body GPUMem) -> Lambda GPUMem -> Body GPUMem
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp (HistOp GPUMem -> Lambda GPUMem) -> HistOp GPUMem -> Lambda GPUMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Put final bucket value in global memory." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let global_is :: [TExp Int64]
global_is =
(VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
segment_is
[TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
tblock_id TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`rem` Count NumBlocks (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
unCount Count NumBlocks (TExp Int64)
blocks_per_segment]
[TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
global_bucket_is
[(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
xparams [VName]
global_dests) (((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
xp, VName
global_dest) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
global_dest [TExp Int64]
global_is (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
xp) []
-- | Drive the shared-memory histogram strategy.
--
-- The intermediate (per-block) arrays are prepared once via
-- 'prepareIntermediateArraysLocal'. A separate 'histKernelLocalPass' is
-- then issued for every chunk index produced by @sFor "chk_i" hist_S@.
-- Each pass receives the same configuration plus that chunk index.
histKernelLocal ::
  TV Int32 ->
  Count NumBlocks (Imp.TExp Int64) ->
  [PatElem LetDecMem] ->
  Count NumBlocks SubExp ->
  Count BlockSize SubExp ->
  SegSpace ->
  Imp.TExp Int32 ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  CallKernelGen ()
histKernelLocal subhistosVar blocksPerSeg pes numBlocks blockSize segSpace histS histSlugs body = do
  -- Report how many shared-memory subhistograms each block will hold.
  emit . Imp.DebugPrint "Number of local subhistograms per block" . Just . untyped $
    tvExp subhistosVar

  initHists <- prepareIntermediateArraysLocal subhistosVar blocksPerSeg histSlugs

  -- Everything except the chunk index is fixed across passes.
  let runPass =
        histKernelLocalPass
          subhistosVar
          blocksPerSeg
          pes
          numBlocks
          blockSize
          segSpace
          histSlugs
          body
          initHists
          histS

  sFor "chk_i" histS runPass
-- | Upper bound on the number of shared-memory passes permitted for a slug.
--
-- The bound depends only on the slug's atomic-update strategy:
-- primitive atomics allow 3, compare-and-swap allows 4, and locking allows 6.
slugMaxLocalMemPasses :: SegHistSlug -> Int
slugMaxLocalMemPasses = passesFor . slugAtomicUpdate
  where
    passesFor AtomicPrim {} = 3
    passesFor AtomicCAS {} = 4
    passesFor AtomicLocking {} = 6
localMemoryCase ::
[PatElem LetDecMem] ->
Imp.TExp Int32 ->
SegSpace ->
Imp.TExp Int64 ->
Imp.TExp Int64 ->
Imp.TExp Int64 ->
Imp.TExp Int32 ->
[SegHistSlug] ->
KernelBody GPUMem ->
CallKernelGen (Imp.TExp Bool, CallKernelGen ())
localMemoryCase :: [PatElem LParamMem]
-> TExp Int32
-> SegSpace
-> TExp Int64
-> TExp Int64
-> TExp Int64
-> TExp Int32
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen (TExp Bool, CallKernelGen ())
localMemoryCase [PatElem LParamMem]
map_pes TExp Int32
hist_T SegSpace
space TExp Int64
hist_H TExp Int64
hist_el_size TExp Int64
hist_N TExp Int32
_ [SegHistSlug]
slugs KernelBody GPUMem
kbody = do
let space_sizes :: [SubExp]
space_sizes = SegSpace -> [SubExp]
segSpaceDims SegSpace
space
segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
init [SubExp]
space_sizes
segmented :: Bool
segmented = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [SubExp]
segment_dims
TV Int64
hist_L <- [Char] -> SizeClass -> ImpM GPUMem HostEnv HostOp (TV Int64)
getSize [Char]
"hist_L" SizeClass
Imp.SizeSharedMemory
TV Int64
max_tblock_size :: TV Int64 <- [Char] -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall {k} (t :: k) rep r op.
MkTV t =>
[Char] -> ImpM rep r op (TV t)
dPrim [Char]
"max_tblock_size"
HostOp -> CallKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> SizeClass -> HostOp
Imp.GetSizeMax (TV Int64 -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV Int64
max_tblock_size) SizeClass
Imp.SizeThreadBlock
let withSizeMax :: Map VName (VarEntry GPUMem) -> Map VName (VarEntry GPUMem)
withSizeMax Map VName (VarEntry GPUMem)
vtable =
case VName -> Map VName (VarEntry GPUMem) -> Maybe (VarEntry GPUMem)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (TV Int64 -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV Int64
max_tblock_size) Map VName (VarEntry GPUMem)
vtable of
Just (ScalarVar Maybe (Exp GPUMem)
_ ScalarEntry
se) ->
VName
-> VarEntry GPUMem
-> Map VName (VarEntry GPUMem)
-> Map VName (VarEntry GPUMem)
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert
(TV Int64 -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV Int64
max_tblock_size)
(Maybe (Exp GPUMem) -> ScalarEntry -> VarEntry GPUMem
forall rep. Maybe (Exp rep) -> ScalarEntry -> VarEntry rep
ScalarVar (Exp GPUMem -> Maybe (Exp GPUMem)
forall a. a -> Maybe a
Just (Op GPUMem -> Exp GPUMem
forall rep. Op rep -> Exp rep
Op (HostOp NoOp GPUMem -> MemOp (HostOp NoOp) GPUMem
forall (inner :: * -> *) rep. inner rep -> MemOp inner rep
Inner (SizeOp -> HostOp NoOp GPUMem
forall (op :: * -> *) rep. SizeOp -> HostOp op rep
SizeOp (SizeClass -> SizeOp
GetSizeMax SizeClass
SizeThreadBlock))))) ScalarEntry
se)
Map VName (VarEntry GPUMem)
vtable
Maybe (VarEntry GPUMem)
_ -> Map VName (VarEntry GPUMem)
vtable
let tblock_size :: Count BlockSize SubExp
tblock_size = SubExp -> Count BlockSize SubExp
forall {k} (u :: k) e. e -> Count u e
Imp.Count (SubExp -> Count BlockSize SubExp)
-> SubExp -> Count BlockSize SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV Int64
max_tblock_size
Count NumBlocks SubExp
num_tblocks <-
(TV Int64 -> Count NumBlocks SubExp)
-> ImpM GPUMem HostEnv HostOp (TV Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumBlocks SubExp)
forall a b.
(a -> b)
-> ImpM GPUMem HostEnv HostOp a -> ImpM GPUMem HostEnv HostOp b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> Count NumBlocks SubExp
forall {k} (u :: k) e. e -> Count u e
Imp.Count (SubExp -> Count NumBlocks SubExp)
-> (TV Int64 -> SubExp) -> TV Int64 -> Count NumBlocks SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TV Int64 -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize) (ImpM GPUMem HostEnv HostOp (TV Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumBlocks SubExp))
-> ImpM GPUMem HostEnv HostOp (TV Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumBlocks SubExp)
forall a b. (a -> b) -> a -> b
$
[Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"num_tblocks" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_T TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` SubExp -> TExp Int64
pe64 (Count BlockSize SubExp -> SubExp
forall {k} (u :: k) e. Count u e -> e
unCount Count BlockSize SubExp
tblock_size)
let num_tblocks' :: Count NumBlocks (TExp Int64)
num_tblocks' = SubExp -> TExp Int64
pe64 (SubExp -> TExp Int64)
-> Count NumBlocks SubExp -> Count NumBlocks (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count NumBlocks SubExp
num_tblocks
tblock_size' :: Count BlockSize (TExp Int64)
tblock_size' = SubExp -> TExp Int64
pe64 (SubExp -> TExp Int64)
-> Count BlockSize SubExp -> Count BlockSize (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count BlockSize SubExp
tblock_size
let r64 :: TPrimExp t v -> TPrimExp Double v
r64 = PrimExp v -> TPrimExp Double v
forall v. PrimExp v -> TPrimExp Double v
isF64 (PrimExp v -> TPrimExp Double v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Double v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> FloatType -> ConvOp
SIToFP IntType
Int64 FloatType
Float64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped
t64 :: TPrimExp t v -> TPrimExp Int64 v
t64 = PrimExp v -> TPrimExp Int64 v
forall v. PrimExp v -> TPrimExp Int64 v
isInt64 (PrimExp v -> TPrimExp Int64 v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Int64 v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (FloatType -> IntType -> ConvOp
FPToSI FloatType
Float64 IntType
Int64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped
TExp Double
hist_m' <-
[Char] -> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_m_prime" (TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double))
-> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Double
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64
( TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64
(TExp Int64 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_L TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64
hist_el_size))
(TExp Int64
hist_N TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count NumBlocks (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
unCount Count NumBlocks (TExp Int64)
num_tblocks'))
)
TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Int64 -> TExp Double
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Double v
r64 TExp Int64
hist_H
let hist_B :: TExp Int64
hist_B = Count BlockSize (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
unCount Count BlockSize (TExp Int64)
tblock_size'
TExp Int64
hist_M0 <-
[Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_M0" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMax64 TExp Int64
1 (TExp Int64 -> TExp Int64) -> TExp Int64 -> TExp Int64
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TExp Double -> TExp Int64
forall {k} {t :: k} {v}. TPrimExp t v -> TPrimExp Int64 v
t64 TExp Double
hist_m') TExp Int64
hist_B
let q_small :: TExp Int64
q_small = TExp Int64
2
TExp Int64
hist_Nout <- [Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_Nout" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
segment_dims
TExp Int64
hist_Nin <- [Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_Nin" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ SubExp -> TExp Int64
pe64 (SubExp -> TExp Int64) -> SubExp -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
last [SubExp]
space_sizes
TExp Int64
work_asymp_M_max <-
if Bool
segmented
then do
TExp Int32
hist_T_hist_min <-
[Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_T_hist_min" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TExp Int64 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
hist_Nin TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
hist_Nout) (TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_T)
TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
hist_Nout
let r :: TExp Int32
r = TExp Int32
hist_T_hist_min TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
hist_B
[Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"work_asymp_M_max" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
hist_Nin TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` (TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
r TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H)
else
[Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"work_asymp_M_max" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
(TExp Int64
hist_Nout TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_N)
TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` ( (TExp Int64
q_small TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count NumBlocks (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
unCount Count NumBlocks (TExp Int64)
num_tblocks' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H)
TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` [SegHistSlug] -> TExp Int64
forall i a. Num i => [a] -> i
L.genericLength [SegHistSlug]
slugs
)
TV Int32
hist_M <- [Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"hist_M" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TV Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 TExp Int64
hist_M0 TExp Int64
work_asymp_M_max
let hist_M_nonzero :: TExp Int32
hist_M_nonzero = TExp Int32 -> TExp Int32 -> TExp Int32
forall v. TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMax32 TExp Int32
1 (TExp Int32 -> TExp Int32) -> TExp Int32 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
hist_M
TExp Int64
hist_C <-
[Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_C" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int64
hist_B TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_M_nonzero
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local hist_M0" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_M0
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local work asymp M max" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
work_asymp_M_max
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local C" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_C
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local B" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_B
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"local M" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
hist_M
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
[Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"shared memory needed" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$
TExp Int64
hist_H TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_el_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
hist_M)
TExp Int64
local_mem_needed <-
[Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"local_mem_needed" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int64
hist_el_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
hist_M)
TExp Int32
hist_S <-
[Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_S" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> (TExp Int64 -> TExp Int32)
-> TExp Int64
-> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
(TExp Int64
hist_H TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
local_mem_needed TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1) TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_L
let max_S :: TExp Int32
max_S = case KernelBody GPUMem -> Passage
bodyPassage KernelBody GPUMem
kbody of
Passage
MustBeSinglePass -> TExp Int32
1
Passage
MayBeMultiPass -> Int -> TExp Int32
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int -> TExp Int32) -> Int -> TExp Int32
forall a b. (a -> b) -> a -> b
$ [Int] -> Int
forall a (f :: * -> *). (Num a, Ord a, Foldable f) => f a -> a
maxinum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Int) -> [SegHistSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> Int
slugMaxLocalMemPasses [SegHistSlug]
slugs
Count NumBlocks (TExp Int64)
blocks_per_segment <-
if Bool
segmented
then
(TExp Int64 -> Count NumBlocks (TExp Int64))
-> ImpM GPUMem HostEnv HostOp (TExp Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumBlocks (TExp Int64))
forall a b.
(a -> b)
-> ImpM GPUMem HostEnv HostOp a -> ImpM GPUMem HostEnv HostOp b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap TExp Int64 -> Count NumBlocks (TExp Int64)
forall {k} (u :: k) e. e -> Count u e
Count (ImpM GPUMem HostEnv HostOp (TExp Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumBlocks (TExp Int64)))
-> ImpM GPUMem HostEnv HostOp (TExp Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumBlocks (TExp Int64))
forall a b. (a -> b) -> a -> b
$
[Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"blocks_per_segment" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
Count NumBlocks (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
unCount Count NumBlocks (TExp Int64)
num_tblocks' TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64
hist_Nout
else Count NumBlocks (TExp Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumBlocks (TExp Int64))
forall a. a -> ImpM GPUMem HostEnv HostOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Count NumBlocks (TExp Int64)
num_tblocks'
let pick_local :: TExp Bool
pick_local =
TExp Int64
hist_Nin
TExp Int64 -> TExp Int64 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>=. TExp Int64
hist_H
TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. (TExp Int64
local_mem_needed TExp Int64 -> TExp Int64 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
hist_L)
TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. (TExp Int32
hist_S TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int32
max_S)
TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
hist_C
TExp Int64 -> TExp Int64 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
hist_B
TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
hist_M
TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int32
0
run :: CallKernelGen ()
run = do
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"## Using shared memory" Maybe Exp
forall a. Maybe a
Nothing
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Histogram size (H)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_H
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Multiplication degree (M)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
hist_M
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Cooperation level (C)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_C
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Number of chunks (S)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_S
Bool -> CallKernelGen () -> CallKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
segmented (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
[Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Blocks per segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$
Count NumBlocks (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
unCount Count NumBlocks (TExp Int64)
blocks_per_segment
(Map VName (VarEntry GPUMem) -> Map VName (VarEntry GPUMem))
-> CallKernelGen () -> CallKernelGen ()
forall rep r op a.
(VTable rep -> VTable rep) -> ImpM rep r op a -> ImpM rep r op a
localVTable Map VName (VarEntry GPUMem) -> Map VName (VarEntry GPUMem)
withSizeMax (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
TV Int32
-> Count NumBlocks (TExp Int64)
-> [PatElem LParamMem]
-> Count NumBlocks SubExp
-> Count BlockSize SubExp
-> SegSpace
-> TExp Int32
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen ()
histKernelLocal
TV Int32
hist_M
Count NumBlocks (TExp Int64)
blocks_per_segment
[PatElem LParamMem]
map_pes
Count NumBlocks SubExp
num_tblocks
Count BlockSize SubExp
tblock_size
SegSpace
space
TExp Int32
hist_S
[SegHistSlug]
slugs
KernelBody GPUMem
kbody
(TExp Bool, CallKernelGen ())
-> CallKernelGen (TExp Bool, CallKernelGen ())
forall a. a -> ImpM GPUMem HostEnv HostOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (TExp Bool
pick_local, CallKernelGen ()
run)
compileSegHist ::
Pat LetDecMem ->
SegLevel ->
SegSpace ->
[HistOp GPUMem] ->
KernelBody GPUMem ->
CallKernelGen ()
compileSegHist :: Pat LParamMem
-> SegLevel
-> SegSpace
-> [HistOp GPUMem]
-> KernelBody GPUMem
-> CallKernelGen ()
compileSegHist (Pat [PatElem LParamMem]
pes) SegLevel
lvl SegSpace
space [HistOp GPUMem]
ops KernelBody GPUMem
kbody = do
KernelAttrs {kAttrNumBlocks :: KernelAttrs -> Count NumBlocks SubExp
kAttrNumBlocks = Count NumBlocks SubExp
num_tblocks, kAttrBlockSize :: KernelAttrs -> Count BlockSize SubExp
kAttrBlockSize = Count BlockSize SubExp
tblock_size} <-
SegLevel -> CallKernelGen KernelAttrs
lvlKernelAttrs SegLevel
lvl
let num_tblocks' :: Count NumBlocks (TExp Int64)
num_tblocks' = (SubExp -> TExp Int64)
-> Count NumBlocks SubExp -> Count NumBlocks (TExp Int64)
forall a b. (a -> b) -> Count NumBlocks a -> Count NumBlocks b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
pe64 Count NumBlocks SubExp
num_tblocks
tblock_size' :: Count BlockSize (TExp Int64)
tblock_size' = (SubExp -> TExp Int64)
-> Count BlockSize SubExp -> Count BlockSize (TExp Int64)
forall a b. (a -> b) -> Count BlockSize a -> Count BlockSize b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
pe64 Count BlockSize SubExp
tblock_size
dims :: [TExp Int64]
dims = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 ([SubExp] -> [TExp Int64]) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space
num_red_res :: Int
num_red_res = [HistOp GPUMem] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp GPUMem]
ops Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((HistOp GPUMem -> Int) -> [HistOp GPUMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (HistOp GPUMem -> [SubExp]) -> HistOp GPUMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral) [HistOp GPUMem]
ops)
([PatElem LParamMem]
all_red_pes, [PatElem LParamMem]
map_pes) = Int
-> [PatElem LParamMem]
-> ([PatElem LParamMem], [PatElem LParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
num_red_res [PatElem LParamMem]
pes
segment_size :: TExp Int64
segment_size = [TExp Int64] -> TExp Int64
forall a. HasCallStack => [a] -> a
last [TExp Int64]
dims
([Count Bytes (TExp Int64)]
op_hs, [Count Bytes (TExp Int64)]
op_seg_hs, [SegHistSlug]
slugs) <- [(Count Bytes (TExp Int64), Count Bytes (TExp Int64), SegHistSlug)]
-> ([Count Bytes (TExp Int64)], [Count Bytes (TExp Int64)],
[SegHistSlug])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(Count Bytes (TExp Int64), Count Bytes (TExp Int64),
SegHistSlug)]
-> ([Count Bytes (TExp Int64)], [Count Bytes (TExp Int64)],
[SegHistSlug]))
-> ImpM
GPUMem
HostEnv
HostOp
[(Count Bytes (TExp Int64), Count Bytes (TExp Int64), SegHistSlug)]
-> ImpM
GPUMem
HostEnv
HostOp
([Count Bytes (TExp Int64)], [Count Bytes (TExp Int64)],
[SegHistSlug])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (HistOp GPUMem
-> CallKernelGen
(Count Bytes (TExp Int64), Count Bytes (TExp Int64), SegHistSlug))
-> [HistOp GPUMem]
-> ImpM
GPUMem
HostEnv
HostOp
[(Count Bytes (TExp Int64), Count Bytes (TExp Int64), SegHistSlug)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (SegSpace
-> HistOp GPUMem
-> CallKernelGen
(Count Bytes (TExp Int64), Count Bytes (TExp Int64), SegHistSlug)
computeHistoUsage SegSpace
space) [HistOp GPUMem]
ops
TExp Int64
h <- [Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"h" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ Count Bytes (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
Imp.unCount (Count Bytes (TExp Int64) -> TExp Int64)
-> Count Bytes (TExp Int64) -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [Count Bytes (TExp Int64)] -> Count Bytes (TExp Int64)
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Count Bytes (TExp Int64)]
op_hs
TExp Int64
seg_h <- [Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"seg_h" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ Count Bytes (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
Imp.unCount (Count Bytes (TExp Int64) -> TExp Int64)
-> Count Bytes (TExp Int64) -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [Count Bytes (TExp Int64)] -> Count Bytes (TExp Int64)
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Count Bytes (TExp Int64)]
op_seg_hs
TExp Bool -> CallKernelGen () -> CallKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless (TExp Int64
seg_h TExp Int64 -> TExp Int64 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0) (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let hist_B :: TExp Int64
hist_B = Count BlockSize (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
unCount Count BlockSize (TExp Int64)
tblock_size'
TExp Int64
hist_H <- [Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_H" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (HistOp GPUMem -> TExp Int64) -> [HistOp GPUMem] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map HistOp GPUMem -> TExp Int64
histSize [HistOp GPUMem]
ops
let lockSize :: SegHistSlug -> Maybe a
lockSize SegHistSlug
slug = case SegHistSlug -> AtomicUpdate GPUMem KernelEnv
slugAtomicUpdate SegHistSlug
slug of
AtomicLocking {} -> a -> Maybe a
forall a. a -> Maybe a
Just (a -> Maybe a) -> a -> Maybe a
forall a b. (a -> b) -> a -> b
$ PrimType -> a
forall a. Num a => PrimType -> a
primByteSize PrimType
int32
AtomicUpdate GPUMem KernelEnv
_ -> Maybe a
forall a. Maybe a
Nothing
TExp Int64
hist_el_size <-
[Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_el_size" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
(TExp Int64 -> TExp Int64 -> TExp Int64)
-> TExp Int64 -> [TExp Int64] -> TExp Int64
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
L.foldl' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
(+) (TExp Int64
h TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64
hist_H) ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$
(SegHistSlug -> Maybe (TExp Int64))
-> [SegHistSlug] -> [TExp Int64]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe SegHistSlug -> Maybe (TExp Int64)
forall {a}. Num a => SegHistSlug -> Maybe a
lockSize [SegHistSlug]
slugs
TExp Int64
hist_N <- [Char] -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_N" TExp Int64
segment_size
TExp Int32
hist_RF <-
[Char] -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"hist_RF" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
[TExp Int64] -> TExp Int64
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((SegHistSlug -> TExp Int64) -> [SegHistSlug] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TExp Int64
pe64 (SubExp -> TExp Int64)
-> (SegHistSlug -> SubExp) -> SegHistSlug -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> SubExp
forall rep. HistOp rep -> SubExp
histRaceFactor (HistOp GPUMem -> SubExp)
-> (SegHistSlug -> HistOp GPUMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs)
TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` [SegHistSlug] -> TExp Int64
forall i a. Num i => [a] -> i
L.genericLength [SegHistSlug]
slugs
let hist_T :: TExp Int32
hist_T = TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ Count NumBlocks (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
unCount Count NumBlocks (TExp Int64)
num_tblocks' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count BlockSize (TExp Int64) -> TExp Int64
forall {k} (u :: k) e. Count u e -> e
unCount Count BlockSize (TExp Int64)
tblock_size'
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"\n# SegHist" Maybe Exp
forall a. Maybe a
Nothing
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Number of threads (T)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_T
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Desired block size (B)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_B
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Histogram size (H)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_H
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Input elements per histogram (N)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_N
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
[Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Number of segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$
[TExp Int64] -> TExp Int64
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$
((VName, SubExp) -> TExp Int64)
-> [(VName, SubExp)] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TExp Int64
pe64 (SubExp -> TExp Int64)
-> ((VName, SubExp) -> SubExp) -> (VName, SubExp) -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) [(VName, SubExp)]
segment_dims
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Histogram element size (el_size)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_el_size
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Race factor (RF)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_RF
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Memory per set of subhistograms per segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
h
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Memory per set of subhistograms times segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
seg_h
(TExp Bool
use_shared_memory, CallKernelGen ()
run_in_shared_memory) <-
[PatElem LParamMem]
-> TExp Int32
-> SegSpace
-> TExp Int64
-> TExp Int64
-> TExp Int64
-> TExp Int32
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen (TExp Bool, CallKernelGen ())
localMemoryCase [PatElem LParamMem]
map_pes TExp Int32
hist_T SegSpace
space TExp Int64
hist_H TExp Int64
hist_el_size TExp Int64
hist_N TExp Int32
hist_RF [SegHistSlug]
slugs KernelBody GPUMem
kbody
TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf TExp Bool
use_shared_memory CallKernelGen ()
run_in_shared_memory (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
[PatElem LParamMem]
-> Count NumBlocks SubExp
-> Count BlockSize SubExp
-> SegSpace
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen ()
histKernelGlobal [PatElem LParamMem]
map_pes Count NumBlocks SubExp
num_tblocks Count BlockSize SubExp
tblock_size SegSpace
space [SegHistSlug]
slugs KernelBody GPUMem
kbody
let pes_per_op :: [[PatElem LParamMem]]
pes_per_op = [Int] -> [PatElem LParamMem] -> [[PatElem LParamMem]]
forall a. [Int] -> [a] -> [[a]]
chunks ((HistOp GPUMem -> Int) -> [HistOp GPUMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int)
-> (HistOp GPUMem -> [VName]) -> HistOp GPUMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> [VName]
forall rep. HistOp rep -> [VName]
histDest) [HistOp GPUMem]
ops) [PatElem LParamMem]
all_red_pes
[(SegHistSlug, [PatElem LParamMem], HistOp GPUMem)]
-> ((SegHistSlug, [PatElem LParamMem], HistOp GPUMem)
-> CallKernelGen ())
-> CallKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegHistSlug]
-> [[PatElem LParamMem]]
-> [HistOp GPUMem]
-> [(SegHistSlug, [PatElem LParamMem], HistOp GPUMem)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegHistSlug]
slugs [[PatElem LParamMem]]
pes_per_op [HistOp GPUMem]
ops) (((SegHistSlug, [PatElem LParamMem], HistOp GPUMem)
-> CallKernelGen ())
-> CallKernelGen ())
-> ((SegHistSlug, [PatElem LParamMem], HistOp GPUMem)
-> CallKernelGen ())
-> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegHistSlug
slug, [PatElem LParamMem]
red_pes, HistOp GPUMem
op) -> do
let num_histos :: TV Int64
num_histos = SegHistSlug -> TV Int64
slugNumSubhistos SegHistSlug
slug
subhistos :: [VName]
subhistos = (SubhistosInfo -> VName) -> [SubhistosInfo] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map SubhistosInfo -> VName
subhistosArray ([SubhistosInfo] -> [VName]) -> [SubhistosInfo] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> [SubhistosInfo]
slugSubhistos SegHistSlug
slug
let unitHistoCase :: CallKernelGen ()
unitHistoCase =
[(PatElem LParamMem, VName)]
-> ((PatElem LParamMem, VName) -> CallKernelGen ())
-> CallKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem] -> [VName] -> [(PatElem LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
red_pes [VName]
subhistos) (((PatElem LParamMem, VName) -> CallKernelGen ())
-> CallKernelGen ())
-> ((PatElem LParamMem, VName) -> CallKernelGen ())
-> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, VName
subhisto) -> do
VName
pe_mem <-
MemLoc -> VName
memLocName (MemLoc -> VName) -> (ArrayEntry -> MemLoc) -> ArrayEntry -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ArrayEntry -> MemLoc
entryArrayLoc
(ArrayEntry -> VName)
-> ImpM GPUMem HostEnv HostOp ArrayEntry
-> ImpM GPUMem HostEnv HostOp VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM GPUMem HostEnv HostOp ArrayEntry
forall rep r op. VName -> ImpM rep r op ArrayEntry
lookupArray (PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
VName
subhisto_mem <-
MemLoc -> VName
memLocName (MemLoc -> VName) -> (ArrayEntry -> MemLoc) -> ArrayEntry -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ArrayEntry -> MemLoc
entryArrayLoc
(ArrayEntry -> VName)
-> ImpM GPUMem HostEnv HostOp ArrayEntry
-> ImpM GPUMem HostEnv HostOp VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM GPUMem HostEnv HostOp ArrayEntry
forall rep r op. VName -> ImpM rep r op ArrayEntry
lookupArray VName
subhisto
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> Space -> Code HostOp
forall a. VName -> VName -> Space -> Code a
Imp.SetMem VName
pe_mem VName
subhisto_mem (Space -> Code HostOp) -> Space -> Code HostOp
forall a b. (a -> b) -> a -> b
$ [Char] -> Space
Space [Char]
"device"
TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf (TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
num_histos TExp Int64 -> TExp Int64 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
1) CallKernelGen ()
unitHistoCase (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
[VName]
bucket_ids <-
Int
-> ImpM GPUMem HostEnv HostOp VName
-> ImpM GPUMem HostEnv HostOp [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape HistOp GPUMem
op)) ([Char] -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"bucket_id")
VName
subhistogram_id <- [Char] -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"subhistogram_id"
[VName]
vector_ids <-
Int
-> ImpM GPUMem HostEnv HostOp VName
-> ImpM GPUMem HostEnv HostOp [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp GPUMem
op)) ([Char] -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"vector_id")
VName
flat_gtid <- [Char] -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"flat_gtid"
let grid :: KernelGrid
grid = Count NumBlocks SubExp -> Count BlockSize SubExp -> KernelGrid
KernelGrid Count NumBlocks SubExp
num_tblocks Count BlockSize SubExp
tblock_size
segred_space :: SegSpace
segred_space =
VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
flat_gtid ([(VName, SubExp)] -> SegSpace) -> [(VName, SubExp)] -> SegSpace
forall a b. (a -> b) -> a -> b
$
[(VName, SubExp)]
segment_dims
[(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
bucket_ids (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape HistOp GPUMem
op))
[(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
vector_ids (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histOpShape HistOp GPUMem
op)
[(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
subhistogram_id, VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV Int64
num_histos)]
subst :: Map VName VName
subst = VName -> VName -> Map VName VName
forall k a. k -> a -> Map k a
M.singleton (SegSpace -> VName
segFlat SegSpace
space) VName
flat_gtid
let segred_op :: SegBinOp GPUMem
segred_op = Commutativity
-> Lambda GPUMem -> [SubExp] -> Shape -> SegBinOp GPUMem
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
Commutative (Map VName VName -> Lambda GPUMem -> Lambda GPUMem
forall a. Substitute a => Map VName VName -> a -> a
substituteNames Map VName VName
subst (Lambda GPUMem -> Lambda GPUMem) -> Lambda GPUMem -> Lambda GPUMem
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp GPUMem
op) (HistOp GPUMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral HistOp GPUMem
op) Shape
forall a. Monoid a => a
mempty
Pat LParamMem
-> KernelGrid
-> SegSpace
-> [SegBinOp GPUMem]
-> DoSegBody
-> CallKernelGen ()
compileSegRed' ([PatElem LParamMem] -> Pat LParamMem
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem LParamMem]
red_pes) KernelGrid
grid SegSpace
segred_space [SegBinOp GPUMem
segred_op] (DoSegBody -> CallKernelGen ()) -> DoSegBody -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[(SubExp, [TExp Int64])] -> InKernelGen ()
red_cont ->
[(SubExp, [TExp Int64])] -> InKernelGen ()
red_cont ([(SubExp, [TExp Int64])] -> InKernelGen ())
-> ((VName -> (SubExp, [TExp Int64])) -> [(SubExp, [TExp Int64])])
-> (VName -> (SubExp, [TExp Int64]))
-> InKernelGen ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((VName -> (SubExp, [TExp Int64]))
-> [VName] -> [(SubExp, [TExp Int64])])
-> [VName]
-> (VName -> (SubExp, [TExp Int64]))
-> [(SubExp, [TExp Int64])]
forall a b c. (a -> b -> c) -> b -> a -> c
flip (VName -> (SubExp, [TExp Int64]))
-> [VName] -> [(SubExp, [TExp Int64])]
forall a b. (a -> b) -> [a] -> [b]
map [VName]
subhistos ((VName -> (SubExp, [TExp Int64])) -> InKernelGen ())
-> (VName -> (SubExp, [TExp Int64])) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
subhisto ->
( VName -> SubExp
Var VName
subhisto,
(VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 ([VName] -> [TExp Int64]) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$
((VName, SubExp) -> VName) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst [(VName, SubExp)]
segment_dims
[VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName
subhistogram_id]
[VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
bucket_ids
[VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
vector_ids
)
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"" Maybe Exp
forall a. Maybe a
Nothing
where
segment_dims :: [(VName, SubExp)]
segment_dims = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. HasCallStack => [a] -> [a]
init ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space