{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.MemoryBlockMerging (optimise) where
import Control.Exception
import Control.Monad.State.Strict
import Data.Function ((&))
import Data.Map (Map, (!))
import Data.Map qualified as M
import Data.Set (Set)
import Data.Set qualified as S
import Futhark.Analysis.Interference qualified as Interference
import Futhark.Builder.Class
import Futhark.Construct
import Futhark.IR.GPUMem
import Futhark.Optimise.MemoryBlockMerging.GreedyColoring qualified as GreedyColoring
import Futhark.Pass (Pass (..), PassM)
import Futhark.Pass qualified as Pass
import Futhark.Util (invertMap)
-- | The allocations found in a kernel, keyed by the name each allocation
-- binds, together with the requested size and memory space.
type Allocs = M.Map VName (SubExp, Space)
-- | Collect the allocations performed by a statement.
--
-- An @Alloc@ contributes its bound name, size and space.  The branches of a
-- @Match@ and the body of a @Loop@ are searched recursively.  Every other
-- statement, including nested kernels ('SegOp's), contributes nothing.
getAllocsStm :: Stm GPUMem -> Allocs
getAllocsStm (Let (Pat [PatElem name _]) _ (Op (Alloc se sp))) =
  M.singleton name (se, sp)
getAllocsStm (Let _ _ (Op (Alloc _ _))) =
  -- An Alloc is expected to bind exactly one name; any other pattern shape
  -- is an invariant violation, so report it with enough context to debug.
  error "getAllocsStm: Alloc statement does not bind exactly one name"
getAllocsStm (Let _ _ (Match _ cases defbody _)) =
  foldMap (foldMap getAllocsStm . bodyStms) (defbody : map caseBody cases)
getAllocsStm (Let _ _ (Loop _ _ body)) =
  foldMap getAllocsStm (bodyStms body)
getAllocsStm _ = mempty
-- | Collect the allocations made by the statements of a kernel's body, by
-- folding 'getAllocsStm' over them.  The kind of 'SegOp' does not matter;
-- only its kernel body is inspected.
getAllocsSegOp :: SegOp lvl GPUMem -> Allocs
getAllocsSegOp segop = foldMap getAllocsStm (kernelBodyStms kbody)
  where
    kbody = case segop of
      SegMap _ _ _ b -> b
      SegRed _ _ _ b _ -> b
      SegScan _ _ _ b _ -> b
      SegHist _ _ _ b _ -> b
-- | Rewrite allocations according to a replacement map.
--
-- A single-name @Alloc@ whose name is a key of the map becomes
-- @BasicOp (SubExp s)@, where @s@ is the mapped subexpression.  Any other
-- @Alloc@ is kept unchanged.  The rewrite is applied recursively inside
-- nested kernels, the branches of @Match@ expressions and loop bodies.
-- Every other statement is returned as-is.
setAllocsStm :: Map VName SubExp -> Stm GPUMem -> Stm GPUMem
setAllocsStm m stm =
  case stmExp stm of
    Op (Alloc _ _)
      | Pat [PatElem name _] <- stmPat stm,
        Just s <- M.lookup name m ->
          stm {stmExp = BasicOp $ SubExp s}
      | otherwise ->
          stm
    Op (Inner (SegOp segop)) ->
      stm {stmExp = Op $ Inner $ SegOp $ setAllocsSegOp m segop}
    Match cond cases defbody dec ->
      stm {stmExp = Match cond (map (fmap rewriteBody) cases) (rewriteBody defbody) dec}
    Loop merge form body ->
      stm {stmExp = Loop merge form (rewriteBody body)}
    _ ->
      stm
  where
    rewriteBody body = body {bodyStms = fmap (setAllocsStm m) (bodyStms body)}
-- | Apply 'setAllocsStm' to every statement of a kernel's body.  The level,
-- space, result types and operators of the 'SegOp' are left untouched.
setAllocsSegOp ::
  Map VName SubExp ->
  SegOp lvl GPUMem ->
  SegOp lvl GPUMem
setAllocsSegOp m segop =
  case segop of
    SegMap lvl sp tps body -> SegMap lvl sp tps (rewrite body)
    SegRed lvl sp tps body ops -> SegRed lvl sp tps (rewrite body) ops
    SegScan lvl sp tps body ops -> SegScan lvl sp tps (rewrite body) ops
    SegHist lvl sp tps body ops -> SegHist lvl sp tps (rewrite body) ops
  where
    rewrite kbody =
      kbody {kernelBodyStms = fmap (setAllocsStm m) (kernelBodyStms kbody)}
-- | Emit statements computing the maximum of a non-empty set of sizes, and
-- return the subexpression holding the result.
--
-- The maximum is built as a left-nested chain of @UMax Int64@ operations,
-- each bound to a fresh name with the hint @"maxSubHelper"@.  A singleton set
-- is returned directly, without emitting any statements.  The set must not
-- be empty.
maxSubExp :: (MonadBuilder m) => Set SubExp -> m SubExp
maxSubExp = helper . S.toList
  where
    helper (s1 : s2 : sexps) = do
      z <- letSubExp "maxSubHelper" $ BasicOp $ BinOp (UMax Int64) s1 s2
      helper (z : sexps)
    helper [s] =
      pure s
    helper [] =
      -- Callers pass the (necessarily non-empty) members of a colour class;
      -- report the violated precondition rather than a bare "impossible".
      error "maxSubExp: cannot take the maximum of an empty set of sizes"
-- | Does an allocation have a size that is available in the given scope?
--
-- Constant sizes always qualify.  A variable size qualifies only if it is
-- bound in @scope@.  'optimiseKernel' passes the scope in effect where the
-- kernel is being rewritten, so this presumably rejects sizes computed
-- inside the kernel body.
isKernelInvariant :: Scope GPUMem -> (SubExp, space) -> Bool
isKernelInvariant scope (size, _) =
  case size of
    Var v -> M.member v scope
    _ -> True
-- | Is the allocation placed in a 'ScalarSpace'?  Such allocations are
-- excluded from merging by 'optimiseKernel'.
isScalarSpace :: (subExp, Space) -> Bool
isScalarSpace alloc =
  case snd alloc of
    ScalarSpace {} -> True
    _ -> False
-- | Run a monadic transformation over the statements of a kernel's body.
-- Everything else in the 'SegOp' (level, space, result types, operators) is
-- preserved as-is.
onKernelBodyStms ::
  (MonadBuilder m) =>
  SegOp lvl GPUMem ->
  (Stms GPUMem -> m (Stms GPUMem)) ->
  m (SegOp lvl GPUMem)
onKernelBodyStms segop f =
  case segop of
    SegMap lvl space ts body ->
      SegMap lvl space ts <$> withNewStms body
    SegRed lvl space ts body binops ->
      (\body' -> SegRed lvl space ts body' binops) <$> withNewStms body
    SegScan lvl space ts body binops ->
      (\body' -> SegScan lvl space ts body' binops) <$> withNewStms body
    SegHist lvl space ts body histops ->
      (\body' -> SegHist lvl space ts body' histops) <$> withNewStms body
  where
    withNewStms body = do
      stms <- f (kernelBodyStms body)
      pure $ body {kernelBodyStms = stms}
-- | Merge the memory blocks allocated inside a kernel.
--
-- Nested kernels in the body are processed first, recursively.  Then the
-- allocations of this kernel that are both kernel-invariant and not in a
-- scalar space are greedily coloured using the interference graph.
-- Allocations sharing a colour share a single block, sized by the maximum of
-- their sizes.  The statements computing those maxima and the new per-colour
-- allocations are prepended to the kernel body.  Each original allocation
-- that was coloured is replaced by a reference to its colour's block.
optimiseKernel ::
  (MonadBuilder m, Rep m ~ GPUMem) =>
  Interference.Graph VName ->
  SegOp lvl GPUMem ->
  m (SegOp lvl GPUMem)
optimiseKernel graph segop0 = do
  segop <- onKernelBodyStms segop0 (onKernels (optimiseKernel graph))
  scope_here <- askScope
  let mergeable alloc =
        isKernelInvariant scope_here alloc && not (isScalarSpace alloc)
      allocs = M.filter mergeable (getAllocsSegOp segop)
      (colorspaces, coloring) =
        GreedyColoring.colorGraph (snd <$> allocs) graph
      block_size name = fst (allocs ! name)
      color_classes = M.elems (invertMap coloring)
  -- One maximum per colour class, in ascending colour order.
  (maxes, maxstms) <-
    collectStms $ mapM (maxSubExp . S.map block_size) color_classes
  let alloc_color (i, size) =
        letSubExp "color" $ Op $ Alloc size $ colorspaces ! i
  (colors, stms) <-
    collectStms $
      mapM alloc_color $
        zip [0 ..] $
          assert (length maxes == M.size colorspaces) maxes
  let segop' = setAllocsSegOp ((colors !!) <$> coloring) segop
      with_merged_allocs body =
        body {kernelBodyStms = maxstms <> stms <> kernelBodyStms body}
  pure $ case segop' of
    SegMap lvl sp tps body ->
      SegMap lvl sp tps (with_merged_allocs body)
    SegRed lvl sp tps body ops ->
      SegRed lvl sp tps (with_merged_allocs body) ops
    SegScan lvl sp tps body ops ->
      SegScan lvl sp tps (with_merged_allocs body) ops
    SegHist lvl sp tps body ops ->
      SegHist lvl sp tps (with_merged_allocs body) ops
-- | Apply @f@ to every kernel ('SegOp') bound in the given statements.
-- The search also descends into the branches of @Match@ expressions and into
-- loop bodies.
--
-- The bindings of the statements are brought into scope before they are
-- visited.  The bodies of the kernels themselves are not searched; that is
-- left to @f@.
--
-- NOTE(review): loop parameters and the loop form's variables are not added
-- to the scope when descending into a loop body.
onKernels ::
  (LocalScope GPUMem m) =>
  (SegOp SegLevel GPUMem -> m (SegOp SegLevel GPUMem)) ->
  Stms GPUMem ->
  m (Stms GPUMem)
onKernels f orig_stms = inScopeOf orig_stms $ traverse visit orig_stms
  where
    visitBody body = do
      body_stms <- onKernels f (bodyStms body)
      pure $ body {bodyStms = body_stms}

    visit stm = case stmExp stm of
      Op (Inner (SegOp segop)) -> do
        segop' <- f segop
        pure $ stm {stmExp = Op $ Inner $ SegOp segop'}
      Match c cases defbody dec -> do
        cases' <- traverse (traverse visitBody) cases
        defbody' <- visitBody defbody
        pure $ stm {stmExp = Match c cases' defbody' dec}
      Loop merge form body -> do
        body' <- visitBody body
        pure $ stm {stmExp = Loop merge form body'}
      _ ->
        pure stm
-- | The memory block merging pass.
--
-- The interference graph is computed once for the whole program.  Then
-- 'optimiseKernel' is run on every kernel reachable via 'onKernels' from the
-- statements handed to us by 'Pass.intraproceduralTransformation'.
optimise :: Pass GPUMem GPUMem
optimise =
  Pass "memory block merging" "memory block merging allocations" run
  where
    run prog =
      Pass.intraproceduralTransformation (onStms graph) prog
      where
        graph = Interference.analyseProgGPU prog

    onStms ::
      Interference.Graph VName ->
      Scope GPUMem ->
      Stms GPUMem ->
      PassM (Stms GPUMem)
    onStms graph scope stms =
      -- Run the builder in a pure State over the name source, threading the
      -- pass's fresh-name supply through it; the statements it collects are
      -- discarded.
      fst <$> modifyNameSource (runState (runBuilderT builder mempty))
      where
        builder = localScope scope $ onKernels (optimiseKernel graph) stms