{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels.Distribution
( Target,
Targets,
ppTargets,
singleTarget,
outerTarget,
innerTarget,
pushInnerTarget,
popInnerTarget,
targetsScope,
LoopNesting (..),
ppLoopNesting,
scopeOfLoopNesting,
Nesting (..),
Nestings,
ppNestings,
letBindInInnerNesting,
singleNesting,
pushInnerNesting,
KernelNest,
ppKernelNest,
newKernel,
innermostKernelNesting,
pushKernelNesting,
pushInnerKernelNesting,
scopeOfKernelNest,
kernelNestLoops,
kernelNestWidths,
boundInKernelNest,
boundInKernelNests,
flatKernel,
constructKernel,
tryDistribute,
tryDistributeStm,
)
where
import Control.Monad
import Control.Monad.RWS.Strict
import Control.Monad.Trans.Maybe
import Data.Bifunctor (second)
import Data.Foldable
import Data.List (elemIndex, sortOn)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.IR
import Futhark.IR.SegOp
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
( DistRep,
KernelInput (..),
MkSegLevel,
mapKernel,
readKernelInput,
)
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util
import Futhark.Util.Log
-- | A distribution target: a pattern together with the result it is
-- paired with (rendered as @pat <- res@ by 'ppTargets').
type Target = (Pat Type, Result)

-- | A stack of targets.
--
-- The first field is the innermost ("current") target.  The list holds
-- the enclosing targets ordered from outermost to innermost, so its
-- head (if any) is the outermost target.
data Targets = Targets
  { _innerTarget :: Target,
    _outerTargets :: [Target]
  }
-- | Render the targets one per line, outermost first and the innermost
-- target last, each as @pat <- res@.
ppTargets :: Targets -> String
ppTargets (Targets target targets) =
  unlines [ppTarget t | t <- targets ++ [target]]
  where
    ppTarget (pat, res) = prettyString pat ++ " <- " ++ prettyString res
-- | A target stack holding only the given target, which is therefore both
-- the innermost and the outermost target.
singleTarget :: Target -> Targets
singleTarget target = Targets target []
-- | The outermost target.  When there are no enclosing targets this is
-- the innermost target itself.
outerTarget :: Targets -> Target
outerTarget (Targets inner_target []) = inner_target
outerTarget (Targets _ (outer_target : _)) = outer_target
-- | The innermost ("current") target.
innerTarget :: Targets -> Target
innerTarget (Targets inner_target _) = inner_target
-- | Add a new outermost target.  The innermost target is unchanged.
pushOuterTarget :: Target -> Targets -> Targets
pushOuterTarget target (Targets inner_target targets) =
  Targets inner_target (target : targets)
-- | Make the given target the new innermost target.  The previous
-- innermost target is appended to the enclosing targets, where it
-- becomes the innermost of them.
--
-- Pattern elements of the new target whose names are not free in the
-- result of the previous innermost target are dropped, together with
-- their positionally corresponding result.
pushInnerTarget :: Target -> Targets -> Targets
pushInnerTarget (pat, res) (Targets inner_target targets) =
  Targets (Pat kept_pes, kept_res) (targets ++ [inner_target])
  where
    inner_used = freeIn $ snd inner_target
    used pe = patElemName pe `nameIn` inner_used
    (kept_pes, kept_res) = unzip $ filter (used . fst) $ zip (patElems pat) res
-- | Remove the innermost target.
--
-- Returns the removed target together with the remaining stack, whose
-- new innermost target is the last (innermost) of the previous enclosing
-- targets.  Returns 'Nothing' when there are no enclosing targets.
popInnerTarget :: Targets -> Maybe (Target, Targets)
popInnerTarget (Targets t ts) =
  case reverse ts of
    x : xs -> Just (t, Targets x $ reverse xs)
    [] -> Nothing
-- | The scope bound by the pattern of a target.
targetScope :: (DistRep rep) => Target -> Scope rep
targetScope = scopeOfPat . fst
-- | The combined scope bound by the patterns of all targets, starting
-- from the innermost target.
targetsScope :: (DistRep rep) => Targets -> Scope rep
targetsScope (Targets t ts) = foldMap targetScope (t : ts)
-- | One level of map nesting.
data LoopNesting = MapNesting
  { -- | The pattern bound by this nesting.
    loopNestingPat :: Pat Type,
    -- | Auxiliary statement information; 'constructKernel' attaches the
    -- one of the outermost nesting to the statement it builds.
    loopNestingAux :: StmAux (),
    -- | The width of this nesting; 'flatKernel' uses it as the extent
    -- of this level's index.
    loopNestingWidth :: SubExp,
    -- | Each parameter paired with the array it is read from.
    loopNestingParamsAndArrs :: [(Param Type, VName)]
  }
  deriving (Show)
-- | The scope bound by the parameters of a loop nesting.
scopeOfLoopNesting :: (LParamInfo rep ~ Type) => LoopNesting -> Scope rep
scopeOfLoopNesting = scopeOfLParams . loopNestingParams
-- | Render a loop nesting as @params <- arrays@.
ppLoopNesting :: LoopNesting -> String
ppLoopNesting nest =
  prettyString params ++ " <- " ++ prettyString arrs
  where
    (params, arrs) = unzip $ loopNestingParamsAndArrs nest
-- | The parameters of a loop nesting, without their arrays.
loopNestingParams :: LoopNesting -> [Param Type]
loopNestingParams = map fst . loopNestingParamsAndArrs
-- | Collects the free names of all four fields of a loop nesting.
instance FreeIn LoopNesting where
  freeIn' (MapNesting pat aux w params_and_arrs) =
    freeIn' pat <> freeIn' aux <> freeIn' w <> freeIn' params_and_arrs
-- | A loop nesting together with the names recorded as let-bound inside
-- it (see 'letBindInNesting').
data Nesting = Nesting
  { nestingLetBound :: Names,
    nestingLoop :: LoopNesting
  }
  deriving (Show)
-- | Record additional names as let-bound inside the nesting.
letBindInNesting :: Names -> Nesting -> Nesting
letBindInNesting newnames (Nesting oldnames loop) =
  Nesting (oldnames <> newnames) loop
-- | A stack of nestings.  The first component is the innermost nesting;
-- the list holds the enclosing nestings ordered from outermost to
-- innermost.
type Nestings = (Nesting, [Nesting])

-- | Render the nestings one per line, outermost first and the innermost
-- last, using 'ppLoopNesting'.  The let-bound names are not shown.
ppNestings :: Nestings -> String
ppNestings (nesting, nestings) =
  unlines $ map (ppLoopNesting . nestingLoop) $ nestings ++ [nesting]
-- | A nesting stack holding a single nesting.
singleNesting :: Nesting -> Nestings
singleNesting nesting = (nesting, [])
-- | Make the given nesting the new innermost one.  The previous innermost
-- nesting is appended to the enclosing nestings.
pushInnerNesting :: Nesting -> Nestings -> Nestings
pushInnerNesting nesting (inner_nesting, nestings) =
  (nesting, nestings ++ [inner_nesting])
-- | All names bound by a nesting: the names of its loop parameters and
-- the names recorded as let-bound inside it.
boundInNesting :: Nesting -> Names
boundInNesting nesting =
  namesFromList (map paramName $ loopNestingParams $ nestingLoop nesting)
    <> nestingLetBound nesting
-- | Record additional names as let-bound in the innermost nesting.
letBindInInnerNesting :: Names -> Nestings -> Nestings
letBindInInnerNesting names (nest, nestings) =
  (letBindInNesting names nest, nestings)
-- | A kernel nest.  The first component is the outermost loop nesting;
-- the list holds the remaining nestings ordered from outermost to
-- innermost.
type KernelNest = (LoopNesting, [LoopNesting])

-- | Render the loop nestings of a kernel nest one per line, outermost
-- first.
ppKernelNest :: KernelNest -> String
ppKernelNest = unlines . map ppLoopNesting . kernelNestLoops
-- | The innermost loop nesting of a kernel nest: the last element of the
-- list, or the outermost nesting if there are no inner nestings.
innermostKernelNesting :: KernelNest -> LoopNesting
innermostKernelNesting (nest, nests) =
  case reverse nests of
    innermost : _ -> innermost
    [] -> nest
-- | Wrap a kernel nest in a new outermost nesting.
--
-- The new nesting's pattern is reordered by 'fixNestingPatOrder' to
-- follow the pattern of the previous outermost nesting, which becomes
-- the first of the inner nestings.
pushKernelNesting :: Target -> LoopNesting -> KernelNest -> KernelNest
pushKernelNesting target newnest (nest, nests) =
  ( fixNestingPatOrder newnest target (loopNestingPat nest),
    nest : nests
  )
-- | Add a new innermost nesting to a kernel nest.
--
-- The new nesting's pattern is reordered by 'fixNestingPatOrder' to
-- follow the pattern of the previous innermost nesting (as given by
-- 'innermostKernelNesting').
pushInnerKernelNesting :: Target -> LoopNesting -> KernelNest -> KernelNest
pushInnerKernelNesting target newnest kernel_nest@(nest, nests) =
  (nest, nests ++ [fixNestingPatOrder newnest target innermost_pat])
  where
    innermost_pat = loopNestingPat $ innermostKernelNesting kernel_nest
-- | Reorder the pattern of @nest@ to follow @inner_pat@.
--
-- The identifiers of @nest@'s pattern are paired positionally (via
-- 'zip') with the result of the target.  They are then sorted by the
-- index in @inner_pat@ of the variable each one is paired with.
-- Identifiers paired with a non-variable result, or with a variable that
-- does not occur in @inner_pat@, get index 0.  Because 'sortOn' is
-- stable, identifiers with equal index keep their relative order.  The
-- pattern component of the target is ignored.
fixNestingPatOrder :: LoopNesting -> Target -> Pat Type -> LoopNesting
fixNestingPatOrder nest (_, res) inner_pat =
  nest {loopNestingPat = basicPat $ map fst $ sortOn posInInnerPat paired}
  where
    paired = zip (patIdents $ loopNestingPat nest) res
    inner_names = patNames inner_pat
    posInInnerPat (_, SubExpRes _ (Var v)) =
      fromMaybe 0 $ elemIndex v inner_names
    posInInnerPat _ = 0
-- | A kernel nest consisting of a single loop nesting.
newKernel :: LoopNesting -> KernelNest
newKernel nest = (nest, [])
-- | All loop nestings of a kernel nest, from outermost to innermost.
kernelNestLoops :: KernelNest -> [LoopNesting]
kernelNestLoops (loop, loops) = loop : loops
-- | The combined scope bound by the parameters of every loop nesting in a
-- kernel nest.
scopeOfKernelNest :: (LParamInfo rep ~ Type) => KernelNest -> Scope rep
scopeOfKernelNest = foldMap scopeOfLoopNesting . kernelNestLoops
-- | The names of all loop parameters of a kernel nest, as one set.
boundInKernelNest :: KernelNest -> Names
boundInKernelNest = mconcat . boundInKernelNests
-- | The names of the loop parameters of each nesting in a kernel nest,
-- one set per nesting, from outermost to innermost.
boundInKernelNests :: KernelNest -> [Names]
boundInKernelNests =
  map (namesFromList . map paramName . loopNestingParams) . kernelNestLoops
-- | The widths of each nesting in a kernel nest, from outermost to
-- innermost.
kernelNestWidths :: KernelNest -> [SubExp]
kernelNestWidths = map loopNestingWidth . kernelNestLoops
-- | Construct a statement that binds the pattern of the outermost nesting
-- of @kernel_nest@ to a 'SegOp' computing @inner_body@ over the index
-- space produced by 'flatKernel'.
--
-- Inside the kernel body:
--
-- * the index variables are in scope as @'IndexName' 'Int64'@;
-- * only the kernel inputs whose names are free in @inner_body@ are read;
-- * each result of @inner_body@ is returned with 'ResultMaySimplify'.
--
-- The result types are the outermost pattern's types with one array
-- dimension stripped per index-space dimension.  The statement uses the
-- outermost nesting's 'StmAux'.
--
-- Returns the statement together with the statements accumulated by the
-- builder.  Here these are the auxiliary statements produced by
-- 'mapKernel'.
constructKernel ::
  (DistRep rep, MonadFreshNames m, LocalScope rep m) =>
  MkSegLevel rep m ->
  KernelNest ->
  Body rep ->
  m (Stm rep, Stms rep)
constructKernel mk_lvl kernel_nest inner_body = runBuilderT' $ do
  (ispace, inps) <- flatKernel kernel_nest
  let aux = loopNestingAux first_nest
      pat = loopNestingPat first_nest
      ispace_scope = M.fromList [(gtid, IndexName Int64) | (gtid, _) <- ispace]
      rts = map (stripArray (length ispace)) $ patTypes pat

  (kres, kstms) <- runBuilder . localScope ispace_scope $ do
    mapM_ readKernelInput $ filter inputIsUsed inps
    res <- bodyBind inner_body
    pure $ map (\(SubExpRes cs se) -> Returns ResultMaySimplify cs se) res
  let inner_body' = KernelBody () kstms kres

  (segop, aux_stms) <- lift $ mapKernel mk_lvl ispace [] rts inner_body'
  addStms aux_stms
  pure $ Let pat aux $ Op $ segOp segop
  where
    first_nest = fst kernel_nest
    -- Computed once here instead of once per kernel input.
    body_free = freeIn inner_body
    inputIsUsed input = kernelInputName input `nameIn` body_free
-- | Flatten a kernel nest into an index space and a list of kernel inputs.
--
-- Every nesting level gets a fresh @gtid@ name paired with that level's
-- width.  The index space is ordered from the outermost level to the
-- innermost.  Each level contributes one input per parameter, reading
-- the parameter's array at that level's @gtid@.
--
-- An input from an inner level whose array is a parameter of the
-- enclosing level is rewritten: it reads the enclosing level's array
-- instead, with the enclosing @gtid@ prepended to its indices.
flatKernel ::
  (MonadFreshNames m) =>
  KernelNest ->
  m ([(VName, SubExp)], [KernelInput])
flatKernel (MapNesting _ _ nesting_w params_and_arrs, nests) = do
  -- This level's name is generated before recursing, so names are
  -- allocated outermost level first.
  i <- newVName "gtid"
  let level_inps =
        [ KernelInput pname ptype arr [Var i]
          | (Param _ pname ptype, arr) <- params_and_arrs
        ]
  case nests of
    [] ->
      pure ([(i, nesting_w)], level_inps)
    nest : inner_nests -> do
      (ispace, inps) <- flatKernel (nest, inner_nests)
      let param_arrs = [(paramName p, arr) | (p, arr) <- params_and_arrs]
          fixupInput inp =
            case lookup (kernelInputArray inp) param_arrs of
              Just arr ->
                inp
                  { kernelInputArray = arr,
                    kernelInputIndices = Var i : kernelInputIndices inp
                  }
              Nothing -> inp
      pure ((i, nesting_w) : ispace, level_inps <> map fixupInput inps)
-- | Information about a body being distributed (see
-- 'distributionBodyFromStms').
data DistributionBody = DistributionBody
  { -- | The targets, after identity mappings have been removed from the
    -- innermost one.
    distributionTarget :: Targets,
    -- | Names free in the body.
    distributionFreeInBody :: Names,
    -- | The identity map returned by 'removeIdentityMappingGeneral'.
    distributionIdentityMap :: M.Map VName Ident,
    -- | The target-expansion function returned by
    -- 'removeIdentityMappingGeneral'.
    distributionExpandTarget :: Target -> Target
  }
-- | The pattern of the innermost target of a distribution body.
distributionInnerPat :: DistributionBody -> Pat Type
distributionInnerPat = fst . innerTarget . distributionTarget
-- | Build a 'DistributionBody' for @stms@ with the given targets.
--
-- Identity mappings are removed from the innermost target with
-- 'removeIdentityMappingGeneral', given the names bound by @stms@.  The
-- free names of the body are those free in @stms@ or in the certificates
-- of the innermost result, minus the names bound by @stms@.
--
-- Also returns the innermost result after identity-mapping removal.
distributionBodyFromStms ::
  (ASTRep rep) =>
  Targets ->
  Stms rep ->
  (DistributionBody, Result)
distributionBodyFromStms (Targets (inner_pat, inner_res) targets) stms =
  let bound_by_stms = namesFromList $ M.keys $ scopeOf stms
      (inner_pat', inner_res', inner_identity_map, inner_expand_target) =
        removeIdentityMappingGeneral bound_by_stms inner_pat inner_res
      free =
        (foldMap freeIn stms <> freeIn (map resCerts inner_res))
          `namesSubtract` bound_by_stms
   in ( DistributionBody
          { distributionTarget = Targets (inner_pat', inner_res') targets,
            distributionFreeInBody = free,
            distributionIdentityMap = inner_identity_map,
            distributionExpandTarget = inner_expand_target
          },
        inner_res'
      )
-- | Like 'distributionBodyFromStms', but for a single statement.
distributionBodyFromStm ::
  (ASTRep rep) =>
  Targets ->
  Stm rep ->
  (DistributionBody, Result)
distributionBodyFromStm targets = distributionBodyFromStms targets . oneStm
-- | Try to build a kernel nest for a distribution body by distributing it
-- across every nesting in the given 'Nestings': first the inner nesting,
-- then each outer nesting, pairing each outer nesting with one outer
-- target.
--
-- Returns the new targets and the kernel nest on success.  Returns
-- 'Nothing' if distribution would induce an irregular array, i.e. some
-- resulting parameter type has a dimension bound inside the nest.
-- Calls 'error' if the number of outer nestings differs from the number
-- of outer targets.
createKernelNest ::
  forall rep m.
  (MonadFreshNames m, HasScope rep m) =>
  Nestings ->
  DistributionBody ->
  m (Maybe (Targets, KernelNest))
createKernelNest (inner_nest, nests) distrib_body = do
  let Targets target targets = distributionTarget distrib_body
  when (length nests /= length targets) $
    error $
      "Nests and targets do not match!\n"
        ++ "nests: "
        ++ ppNestings (inner_nest, nests)
        ++ "\ntargets:"
        ++ ppTargets (Targets target targets)
  runMaybeT $ do
    (kernel, _, new_targets) <- go $ zip nests targets
    pure (new_targets, kernel)
  where
    -- Every name bound by any nesting, inner or outer.
    bound_in_nest = foldMap boundInNesting (inner_nest : nests)

    -- A type can be moved out of the nest only if none of its dimensions
    -- refer to names bound inside the nest.
    distributableType :: Type -> Bool
    distributableType t =
      namesIntersection bound_in_nest (freeIn (arrayDims t)) == mempty

    distributeAtNesting ::
      Nesting ->
      Pat Type ->
      (LoopNesting -> KernelNest, Names) ->
      M.Map VName Ident ->
      [Ident] ->
      (Target -> Targets) ->
      MaybeT m (KernelNest, Names, Targets)
    distributeAtNesting
      (Nesting nest_let_bound nest)
      pat
      (add_to_kernel, free_in_kernel)
      identity_map
      inner_returned_arrs
      addTarget = do
        let trimmed@(MapNesting _ aux w params_and_arrs) =
              removeUnusedNestingParts free_in_kernel nest
            (params, arrs) = unzip params_and_arrs
            -- Names the kernel needs that are let-bound in this nesting;
            -- they must additionally be returned from it.
            required_from_nest =
              ( (freeIn trimmed <> free_in_kernel)
                  `namesSubtract` namesFromList (map paramName params)
              )
                `namesIntersection` nest_let_bound

        required_idents <-
          forM (namesToList required_from_nest) $ \name ->
            Ident name <$> lift (lookupType name)

        let returnedArray :: Ident -> MaybeT m (Param Type, Ident, Bool)
            returnedArray (Ident pname ptype) =
              case M.lookup pname identity_map of
                -- Identity-mapped: reuse the known array; it is not bound
                -- by the new target.
                Just arr -> pure (Param mempty pname ptype, arr, False)
                -- Otherwise bind a fresh array with one row per iteration.
                Nothing -> do
                  arr <- newIdent (baseString pname ++ "_r") $ arrayOfRow ptype w
                  pure (Param mempty pname ptype, arr, True)

        (free_params, free_arrs, bind_in_target) <-
          unzip3 <$> mapM returnedArray (inner_returned_arrs ++ required_idents)

        let -- Keep only the entries that the new target must bind.
            selected :: [a] -> [a]
            selected = map snd . filter fst . zip bind_in_target

            free_arrs_pat = basicPat $ selected free_arrs
            free_params_pat = selected free_params

            actual_params = params ++ free_params
            actual_arrs = arrs ++ map identName free_arrs

            nest'' =
              removeUnusedNestingParts free_in_kernel $
                MapNesting pat aux w $
                  zip actual_params actual_arrs

            free_in_kernel'' =
              (freeIn nest'' <> free_in_kernel)
                `namesSubtract` namesFromList (map paramName actual_params)

        unless (all (distributableType . paramType) (loopNestingParams nest'')) $
          fail "Would induce irregular array"

        pure
          ( add_to_kernel nest'',
            free_in_kernel'',
            addTarget (free_arrs_pat, varsRes $ map paramName free_params_pat)
          )

    -- Distribute at the inner nesting (empty list) and then, on the way
    -- back out of the recursion, at each outer nesting.
    go :: [(Nesting, Target)] -> MaybeT m (KernelNest, Names, Targets)
    go [] =
      distributeAtNesting
        inner_nest
        (distributionInnerPat distrib_body)
        ( newKernel,
          distributionFreeInBody distrib_body `namesIntersection` bound_in_nest
        )
        (distributionIdentityMap distrib_body)
        []
        (singleTarget . distributionExpandTarget distrib_body)
    go ((nest, (pat, res)) : outer_nests) = do
      (kernel@(outer, _), kernel_free, kernel_targets) <- go outer_nests
      let (pat', res', identity_map, expand_target) =
            removeIdentityMappingFromNesting
              (namesFromList $ patNames $ loopNestingPat outer)
              pat
              res
      distributeAtNesting
        nest
        pat'
        (\k -> pushKernelNesting (pat', res') k kernel, kernel_free)
        identity_map
        (patIdents $ fst $ outerTarget kernel_targets)
        ((`pushOuterTarget` kernel_targets) . expand_target)
-- | Drop every parameter of a map nesting whose name is not in @used@,
-- together with the array it is bound to.
removeUnusedNestingParts :: Names -> LoopNesting -> LoopNesting
removeUnusedNestingParts used (MapNesting pat aux w params_and_arrs) =
  MapNesting pat aux w kept
  where
    kept =
      [ (param, arr)
        | (param, arr) <- params_and_arrs,
          paramName param `nameIn` used
      ]
-- | Split the elements of a pattern and its result into identity mappings
-- and the rest.  A result is an identity mapping when it is a variable
-- that is not in @bound@.
--
-- Returns:
--
-- * the pattern of the non-identity elements,
-- * their results,
-- * a map from each identity-mapped result variable to the identifier of
--   the pattern element it was bound to,
-- * a function that appends the identity-mapped pattern elements and
--   results back onto a target.
--
-- The certificates on an identity-mapped result are kept, so that
-- expanding a target reattaches them rather than silently dropping them.
removeIdentityMappingGeneral ::
  Names ->
  Pat Type ->
  Result ->
  ( Pat Type,
    Result,
    M.Map VName Ident,
    Target -> Target
  )
removeIdentityMappingGeneral bound pat res =
  let (identities, not_identities) =
        mapEither isIdentity $ zip (patElems pat) res
      (not_identity_patElems, not_identity_res) = unzip not_identities
      (identity_patElems, identity_res) = unzip identities
      expandTarget (tpat, tres) =
        ( Pat $ patElems tpat ++ identity_patElems,
          tres ++ map (uncurry SubExpRes . second Var) identity_res
        )
      identity_map =
        M.fromList $
          zip (map snd identity_res) (map patElemIdent identity_patElems)
   in ( Pat not_identity_patElems,
        not_identity_res,
        identity_map,
        expandTarget
      )
  where
    -- Previously the certificates were replaced by 'mempty' here, which
    -- made 'expandTarget' re-emit the result without its certificates.
    isIdentity (patElem, SubExpRes cs (Var v))
      | v `notNameIn` bound = Left (patElem, (cs, v))
    isIdentity x = Right x
-- | Remove identity mappings from the pattern and result of a nesting,
-- where the given names are those bound by the nesting.  A result counts
-- as an identity mapping when it is a variable not among those names.
-- This delegates directly to 'removeIdentityMappingGeneral'.
removeIdentityMappingFromNesting ::
  Names ->
  Pat Type ->
  Result ->
  ( Pat Type,
    Result,
    M.Map VName Ident,
    Target -> Target
  )
removeIdentityMappingFromNesting = removeIdentityMappingGeneral
-- | Try to distribute the given statements across the nestings.
--
-- On success, returns the new targets together with the statements that
-- replace @stms@: any auxiliary statements produced by 'constructKernel'
-- followed by the (renamed) kernel statement.  A log message describing
-- the distribution is emitted.  An empty statement sequence trivially
-- succeeds with no statements.  Returns 'Nothing' if no kernel nest
-- could be created.
tryDistribute ::
  ( DistRep rep,
    MonadFreshNames m,
    LocalScope rep m,
    MonadLogger m
  ) =>
  MkSegLevel rep m ->
  Nestings ->
  Targets ->
  Stms rep ->
  m (Maybe (Targets, Stms rep))
tryDistribute mk_lvl nest targets stms
  -- No point in distributing an empty kernel.
  | null stms = pure $ Just (targets, mempty)
  | otherwise = do
      maybe_nest <- createKernelNest nest dist_body
      case maybe_nest of
        Nothing -> pure Nothing
        Just (new_targets, kernel_nest) -> do
          (kernel_stm, aux_stms) <-
            localScope (targetsScope new_targets) $
              constructKernel mk_lvl kernel_nest (mkBody stms inner_body_res)
          kernel_stm' <- renameStm kernel_stm
          logMsg $
            concat
              [ "distributing\n",
                unlines $ map prettyString $ stmsToList stms,
                prettyString $ snd $ innerTarget targets,
                "\nas\n",
                prettyString kernel_stm',
                "\ndue to targets\n",
                ppTargets targets,
                "\nand with new targets\n",
                ppTargets new_targets
              ]
          pure $ Just (new_targets, aux_stms <> oneStm kernel_stm')
  where
    (dist_body, inner_body_res) = distributionBodyFromStms targets stms
-- | Try to create a kernel nest for distributing a single statement.
-- On success, the new targets and kernel nest are paired with the
-- non-identity part of the inner result.
tryDistributeStm ::
  (MonadFreshNames m, HasScope t m, ASTRep rep) =>
  Nestings ->
  Targets ->
  Stm rep ->
  m (Maybe (Result, Targets, KernelNest))
tryDistributeStm nest targets stm =
  attachResult <$> createKernelNest nest dist_body
  where
    (dist_body, res) = distributionBodyFromStm targets stm

    attachResult Nothing = Nothing
    attachResult (Just (new_targets, kernel_nest)) =
      Just (res, new_targets, kernel_nest)