module Futhark.Optimise.ReduceDeviceSyncs.MigrationTable
(
analyseFunDef,
analyseConsts,
hostOnlyFunDefs,
MigrationTable,
MigrationStatus (..),
shouldMoveStm,
shouldMove,
usedOnHost,
statusOf,
)
where
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Reader qualified as R
import Control.Monad.Trans.State.Strict ()
import Control.Monad.Trans.State.Strict hiding (State)
import Data.Bifunctor (first, second)
import Data.Foldable
import Data.IntMap.Strict qualified as IM
import Data.IntSet qualified as IS
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe (fromMaybe, isJust, isNothing)
import Data.Sequence qualified as SQ
import Data.Set (Set, (\\))
import Data.Set qualified as S
import Futhark.Error
import Futhark.IR.GPU
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph
( EdgeType (..),
Edges (..),
Id,
IdSet,
Result (..),
Routing (..),
Vertex (..),
)
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph qualified as MG
-- | Where the value bound by a name should be computed, as recorded in a
-- 'MigrationTable'. Names without an entry are reported as 'StayOnHost' by
-- 'statusOf'.
data MigrationStatus
  = -- | The value should be computed on device and is not used on host
    -- ('shouldMove' holds, 'usedOnHost' does not).
    MoveToDevice
  | -- | The value should be computed on device but is still used on host
    -- (both 'shouldMove' and 'usedOnHost' hold).
    UsedOnHost
  | -- | The value should be computed on host ('shouldMove' does not hold).
    StayOnHost
  deriving (Eq, Ord, Show)
-- | Maps variable identifiers (the 'baseTag' of a name, which is how
-- 'statusOf' looks entries up) to their 'MigrationStatus'.
newtype MigrationTable = MigrationTable (IM.IntMap MigrationStatus)
-- | Combines two tables with 'IM.union', which is left-biased: when both
-- tables contain the same identifier, the status from the left operand wins.
instance Semigroup MigrationTable where
  MigrationTable a <> MigrationTable b = MigrationTable (a `IM.union` b)
-- | The migration status of the value bound by the given name. Names that
-- have no entry in the table default to 'StayOnHost'.
statusOf :: VName -> MigrationTable -> MigrationStatus
statusOf n (MigrationTable mt) =
  fromMaybe StayOnHost $ IM.lookup (baseTag n) mt
-- | Should this whole statement be moved from host to device?
--
--   * An 'Index' is moved if the first name it binds is 'MoveToDevice', or if
--     any variable in its slice is 'MoveToDevice'.
--   * Any other 'BasicOp', and any 'Apply', is moved unless the first name it
--     binds is 'StayOnHost'.
--   * A 'Match' is moved if it has at least one variable condition and every
--     variable condition is 'MoveToDevice'.
--   * A 'ForLoop' with a variable bound, or a 'WhileLoop', is moved if that
--     bound/condition variable is 'MoveToDevice'.
--   * Everything else is not moved. This includes statements that bind no
--     names, matches on constant conditions, for-loops with a constant bound,
--     'Op' and 'WithAcc'.
shouldMoveStm :: Stm GPU -> MigrationTable -> Bool
shouldMoveStm (Let (Pat (PatElem n _ : _)) _ (BasicOp (Index _ slice))) mt =
  statusOf n mt == MoveToDevice || any movedOperand slice
  where
    movedOperand (Var op) = statusOf op mt == MoveToDevice
    movedOperand _ = False
shouldMoveStm (Let (Pat (PatElem n _ : _)) _ (BasicOp _)) mt =
  statusOf n mt /= StayOnHost
shouldMoveStm (Let (Pat (PatElem n _ : _)) _ Apply {}) mt =
  statusOf n mt /= StayOnHost
shouldMoveStm (Let _ _ (Match cond _ _ _)) mt =
  -- 'all' is vacuously True on an empty list. Without the 'null' check, a
  -- match whose conditions are all constants would be moved unconditionally,
  -- because no variable's status can then govern the decision. Keep such
  -- matches on host, as is done below for for-loops with a constant bound.
  let cond_vars = subExpVars cond
   in not (null cond_vars)
        && all ((== MoveToDevice) . (`statusOf` mt)) cond_vars
shouldMoveStm (Let _ _ (Loop _ (ForLoop _ _ (Var n)) _)) mt =
  statusOf n mt == MoveToDevice
shouldMoveStm (Let _ _ (Loop _ (WhileLoop n) _)) mt =
  statusOf n mt == MoveToDevice
shouldMoveStm _ _ = False
-- | Should the value bound by this name be computed on device? True for
-- every status except 'StayOnHost'.
shouldMove :: VName -> MigrationTable -> Bool
shouldMove n mt = statusOf n mt /= StayOnHost
-- | Will the value bound by this name be used on host? True for every status
-- except 'MoveToDevice'.
usedOnHost :: VName -> MigrationTable -> Bool
usedOnHost n mt = statusOf n mt /= MoveToDevice
-- | Names of functions classified as host-only, as computed by
-- 'hostOnlyFunDefs'.
type HostOnlyFuns = Set Name
-- | Returns the names of all given functions that are host-only.
--
-- A function is host-only if 'checkFunDef' rejects it, or if it calls,
-- directly or transitively, another host-only function from the given list.
-- This is computed as a fixed point: host-only functions are removed from the
-- call map round by round until a round removes none.
hostOnlyFunDefs :: [FunDef GPU] -> HostOnlyFuns
hostOnlyFunDefs funs =
  let names = map funDefName funs
      call_map = M.fromList $ zip names (map checkFunDef funs)
   in S.fromList names \\ M.keysSet (removeHostOnly call_map)
  where
    -- Drop every entry already known to be host-only ('Nothing'), then mark
    -- as host-only every remaining function that calls one of them. Repeat
    -- until a round drops nothing. The surviving keys are not host-only.
    removeHostOnly ::
      M.Map Name (Maybe HostOnlyFuns) -> M.Map Name (Maybe HostOnlyFuns)
    removeHostOnly cm =
      let (host_only, cm') = M.partition isNothing cm
       in if M.null host_only
            then cm'
            else removeHostOnly $ M.map (checkCalls $ M.keysSet host_only) cm'

    -- A function that calls any host-only function becomes host-only itself.
    -- Other entries keep their call sets.
    checkCalls :: HostOnlyFuns -> Maybe HostOnlyFuns -> Maybe HostOnlyFuns
    checkCalls host_only_funs (Just calls)
      | host_only_funs `S.disjoint` calls = Just calls
    checkCalls _ _ = Nothing
-- | Returns 'Nothing' (host-only) when the function definition does any of
-- the following:
--
--   * takes an array parameter or returns an array type;
--   * binds an array anywhere in its body, including nested 'Match' and
--     'Loop' bodies;
--   * has an array-typed loop parameter;
--   * contains an 'Index', 'WithAcc' or 'Op' expression.
--
-- Otherwise it returns the names of all functions applied via 'Apply'.
-- Whether those callees are themselves host-only is decided by
-- 'hostOnlyFunDefs'.
checkFunDef :: FunDef GPU -> Maybe (Set Name)
checkFunDef fun = do
  checkFParams $ funDefParams fun
  checkRetTypes $ map fst $ funDefRetType fun
  checkBody $ funDefBody fun
  where
    hostOnly :: Maybe a
    hostOnly = Nothing

    ok :: Maybe ()
    ok = Just ()

    -- Fails ('hostOnly') if any element satisfies the array predicate.
    check :: (Foldable t) => (a -> Bool) -> t a -> Maybe ()
    check isArr as = if any isArr as then hostOnly else ok

    checkFParams = check isArray

    checkLParams = check (isArray . fst)

    checkRetTypes = check isArrayType

    checkPats = check isArray

    checkBody = checkStms . bodyStms

    checkStms stms = S.unions <$> mapM checkStm stms

    checkStm (Let (Pat pats) _ e) = checkPats pats >> checkExp e

    checkExp (BasicOp (Index _ _)) = hostOnly
    checkExp (WithAcc _ _) = hostOnly
    checkExp (Op _) = hostOnly
    checkExp (Apply fn _ _ _) = Just (S.singleton fn)
    checkExp (Match _ cases defbody _) =
      mconcat <$> mapM checkBody (defbody : map caseBody cases)
    checkExp (Loop params _ body) = do
      checkLParams params
      checkBody body
    checkExp BasicOp {} = Just S.empty
-- | Variable identifiers that 'buildGraph' connects to the graph's sink.
-- 'analyseFunDef' fills it with the scalar variables returned by a function
-- body, and 'analyseConsts' with the scalar constants that occur free in the
-- function definitions.
type HostUsage = [Id]
-- | Converts a variable name to the 'Id' used by this analysis: its
-- 'baseTag'. This is also the key 'statusOf' uses for table lookups.
nameToId :: VName -> Id
nameToId = baseTag
-- | Computes the migration table for the top-level constant statements
-- @consts@. Every scalar constant bound by @consts@ that occurs free in
-- @funs@ is passed to 'analyseStms' as host usage.
analyseConsts :: HostOnlyFuns -> [FunDef GPU] -> Stms GPU -> MigrationTable
analyseConsts hof funs consts =
  -- Strict fold: the accumulator is forced at each step instead of building
  -- a chain of thunks over the whole constant scope.
  let usage = M.foldlWithKey' (f $ freeIn funs) [] (scopeOf consts)
   in analyseStms hof usage consts
  where
    f free usage n t
      | isScalar t,
        n `nameIn` free =
          nameToId n : usage
      | otherwise =
          usage
-- | Computes the migration table for the body of a function definition.
-- Every body result that is a variable of scalar return type is passed to
-- 'analyseStms' as host usage.
analyseFunDef :: HostOnlyFuns -> FunDef GPU -> MigrationTable
analyseFunDef hof fd =
  let body = funDefBody fd
      usage = foldl' f [] $ zip (bodyResult body) (map fst $ funDefRetType fd)
      stms = bodyStms body
   in analyseStms hof usage stms
  where
    f usage (SubExpRes _ (Var n), t) | isScalarType t = nameToId n : usage
    f usage _ = usage
-- | Builds the migration graph for the given statements and derives a
-- migration table by folding over the graph from every source.
--
-- Each visit of a vertex inserts its id into one of three sets:
--
--   (1) reached through a 'Reversed' edge: becomes 'MoveToDevice';
--   (2) reached through a 'Normal' edge and has 'NoRoute': becomes
--       'MoveToDevice';
--   (3) reached through a 'Normal' edge and has a route: becomes
--       'UsedOnHost'.
--
-- 'IM.unions' is left-biased, so an id present in several sets gets the
-- status of the earliest set in the list.
analyseStms :: HostOnlyFuns -> HostUsage -> Stms GPU -> MigrationTable
analyseStms hof usage stms =
  let (g, srcs, _) = buildGraph hof usage stms
      (routed, unrouted) = srcs
      -- Route the unrouted sources; the list returned by 'MG.routeMany' is
      -- not needed here.
      (_, g') = MG.routeMany unrouted g
      f st' = MG.fold g' visit st' Normal
      -- The visited set is threaded from the unrouted sources into the
      -- routed ones, so a vertex is not re-visited across the two folds.
      st = foldl' f (initial, MG.none) unrouted
      (vr, vn, tn) = fst $ foldl' f st routed
   in MigrationTable $
        IM.unions
          [ IM.fromSet (const MoveToDevice) vr,
            IM.fromSet (const MoveToDevice) vn,
            IM.fromSet (const UsedOnHost) tn
          ]
  where
    initial = (IS.empty, IS.empty, IS.empty)

    visit (vr, vn, tn) Reversed v =
      (IS.insert (vertexId v) vr, vn, tn)
    visit (vr, vn, tn) Normal v@Vertex {vertexRouting = NoRoute} =
      (vr, IS.insert (vertexId v) vn, tn)
    visit (vr, vn, tn) Normal v =
      (vr, vn, IS.insert (vertexId v) tn)
-- | Is the type of this typed thing a scalar, as defined by 'isScalarType'?
isScalar :: (Typed t) => t -> Bool
isScalar = isScalarType . typeOf
-- | Is this a primitive type other than 'Unit'? Non-'Prim' types are never
-- scalars.
isScalarType :: TypeBase shape u -> Bool
isScalarType (Prim Unit) = False
isScalarType (Prim _) = True
isScalarType _ = False
-- | Is the type of this typed thing an array, as defined by 'isArrayType'?
isArray :: (Typed t) => t -> Bool
isArray = isArrayType . typeOf
-- | Is this an array type, i.e. is its rank greater than zero?
isArrayType :: (ArrayShape shape) => TypeBase shape u -> Bool
isArrayType = (0 <) . arrayRank
-- | Builds the migration graph for a sequence of statements, then connects
-- every id in the host usage list to the graph's sink. Returns that graph
-- together with the sources and sinks produced by 'execGrapher'.
buildGraph :: HostOnlyFuns -> HostUsage -> Stms GPU -> (Graph, Sources, Sinks)
buildGraph hof usage stms =
  let (g, srcs, sinks) = execGrapher hof (graphStms stms)
      g' = foldl' (flip MG.connectToSink) g usage
   in (g', srcs, sinks)
-- | Graphs the statements of a body under 'incBodyDepthFor', and reports the
-- free variables of the body result via 'tellOperands'.
--
-- Afterwards, if the body's depth (current depth + 1) appears among the
-- 'bodyHostOnlyParents' captured for the body, the current 'stateStats' are
-- marked 'bodyHostOnly'. In either case, that depth is removed from the
-- current 'bodyHostOnlyParents'.
graphBody :: Body GPU -> Grapher ()
graphBody body = do
  let res_ops = namesIntSet $ freeIn (bodyResult body)
  body_stats <-
    captureBodyStats $
      incBodyDepthFor (graphStms (bodyStms body) >> tellOperands res_ops)

  -- Assumes 'incBodyDepthFor' graphed the body one level deeper than the
  -- current depth. TODO(review): confirm against its definition.
  body_depth <- (1 +) <$> getBodyDepth
  let host_only = IS.member body_depth (bodyHostOnlyParents body_stats)
  modify $ \st ->
    let stats = stateStats st
        hops' = IS.delete body_depth (bodyHostOnlyParents stats)
        stats' = if host_only then stats {bodyHostOnly = True} else stats
     in st {stateStats = stats' {bodyHostOnlyParents = hops'}}
-- | Graphs each statement in order with 'graphStm'.
graphStms :: Stms GPU -> Grapher ()
graphStms = mapM_ graphStm
-- | Graph a single statement.
--
-- Dispatches on the expression bound by the statement. All 'BasicOp' forms
-- are handled by the local 'graphBasicOp'; every other expression form is
-- delegated to its dedicated graphing function. Alternatives that need the
-- statement's sole binding obtain it through 'exactlyOne', which reports a
-- compiler bug if the pattern does not have exactly one element.
graphStm :: Stm GPU -> Grapher ()
graphStm stm =
  case e of
    BasicOp op -> graphBasicOp op
    Apply fn _ _ _ -> graphApply fn bs e
    Match ses cases defbody _ -> graphMatch bs ses cases defbody
    Loop params lform body -> graphLoop bs params lform body
    WithAcc inputs f -> graphWithAcc bs inputs f
    Op GPUBody {} -> tellGPUBody
    Op _ -> graphHostOnly e
  where
    bs :: [Binding]
    bs = boundBy stm

    e :: Exp GPU
    e = stmExp stm

    -- The statement's only binding. Being a lazy binding, it is only forced
    -- (and can only fail) in the alternatives that actually use it.
    single :: Binding
    single = exactlyOne bs

    -- Alternatives are tried from top to bottom and guards fall through, so
    -- the order below is significant: e.g. the single-element-array case must
    -- precede the inefficient-return and host-only cases that follow it.
    graphBasicOp :: BasicOp -> Grapher ()
    graphBasicOp op =
      case op of
        SubExp se -> simpleReusing se
        Opaque _ se -> simpleReusing se
        ArrayLit arr t
          | isScalar t,
            any (isJust . subExpVar) arr ->
              -- A literal of scalars that mentions at least one variable is
              -- added as a source vertex.
              graphAutoMove single
        UnOp {} -> graphSimple bs e
        BinOp {} -> graphSimple bs e
        CmpOp {} -> graphSimple bs e
        ConvOp {} -> graphSimple bs e
        Assert {} -> graphSimple bs e
        Index _ slice
          | isFixed slice ->
              -- The slice consists only of fixed indices.
              graphRead single
        _
          | [(_, t)] <- bs,
            dims <- arrayDims t,
            not (null dims),
            all (== intConst Int64 1) dims ->
              -- Binds a single array whose every dimension is the constant 1.
              graphSimple bs e
        Index arr s -> inefficientReusing (sliceDims s) arr
        Update _ arr slice _
          | isFixed slice -> inefficientReusing [] arr
        FlatIndex arr s -> inefficientReusing (flatSliceDims s) arr
        FlatUpdate arr _ _ -> inefficientReusing [] arr
        Scratch _ s -> graphInefficientReturn s
        Reshape _ s arr -> inefficientReusing (shapeDims s) arr
        Rearrange _ arr -> inefficientReusing [] arr
        ArrayLit {} -> graphHostOnly e
        ArrayVal {} -> graphHostOnly e
        Update {} -> graphHostOnly e
        Concat {} -> graphHostOnly e
        Manifest {} -> graphHostOnly e
        Iota {} -> graphHostOnly e
        Replicate {} -> graphHostOnly e
        UpdateAcc {} -> graphUpdateAcc single e

    -- Graph the statement as a simple one, then record that its binding
    -- reuses the given subexpression.
    simpleReusing :: SubExp -> Grapher ()
    simpleReusing se = do
      graphSimple bs e
      single `reusesSubExp` se

    -- Graph the statement with 'graphInefficientReturn' for the given result
    -- dimensions, then record that its binding reuses the array @arr@.
    inefficientReusing :: [SubExp] -> VName -> Grapher ()
    inefficientReusing new_dims arr = do
      graphInefficientReturn new_dims
      single `reuses` arr

    -- Require every variable among the given dimensions on host, then connect
    -- all graphed scalar operands of the expression to the sink.
    graphInefficientReturn :: [SubExp] -> Grapher ()
    graphInefficientReturn new_dims = do
      mapM_ hostSize new_dims
      graphedScalarOperands e >>= addEdges ToSink

    -- A size given by a variable must be available on host; constants need
    -- nothing.
    hostSize :: SubExp -> Grapher ()
    hostSize = mapM_ (requiredOnHost . nameToId) . subExpVar

    -- Does the slice consist solely of fixed indices?
    isFixed :: Slice d -> Bool
    isFixed slice = isJust (sliceIndices slice)

    exactlyOne :: [a] -> a
    exactlyOne xs =
      case xs of
        [x] -> x
        _ -> compilerBugS "Type error: unexpected number of pattern elements."
-- | The bindings introduced by a statement: one @(id, type)@ pair for each
-- element of its pattern, in pattern order.
boundBy :: Stm GPU -> [Binding]
boundBy stm = [(nameToId n, t) | PatElem n t <- patElems (stmPat stm)]
-- | Graph a statement that has no special migration requirements.
--
-- When the expression has no graphed scalar operands, nothing is recorded.
-- Otherwise a vertex is added for every binding, and the operands are
-- connected to the bound ids through 'MG.declareEdges'.
graphSimple :: [Binding] -> Exp GPU -> Grapher ()
graphSimple bs e = do
  operands <- graphedScalarOperands e
  if IS.null operands
    then pure ()
    else do
      mapM_ addVertex bs
      addEdges (MG.declareEdges (map fst bs)) operands
-- | Graph a statement that reads a single element out of an array: its
-- binding becomes a source vertex and the read is reported via 'tellRead'.
graphRead :: Binding -> Grapher ()
graphRead b = addSource b >> tellRead
-- | Graph a statement whose binding is added to the graph as a source vertex,
-- so that it is migrated whenever routing allows it.
graphAutoMove :: Binding -> Grapher ()
graphAutoMove b = addSource b
-- | Graph a statement that must remain on host: each of its graphed scalar
-- operands is connected to the sink, and host-only code is reported via
-- 'tellHostOnly'.
graphHostOnly :: Exp GPU -> Grapher ()
graphHostOnly e =
  graphedScalarOperands e >>= addEdges ToSink >> tellHostOnly
-- | Record an 'UpdateAcc' statement for later processing.
--
-- The statement is not graphed immediately. Instead the pair of its binding
-- and expression is prepended to the list of delayed updates kept in
-- 'stateUpdateAccs', keyed by the id of the accumulator named in the
-- binding's 'Acc' type.
graphUpdateAcc :: Binding -> Exp GPU -> Grapher ()
graphUpdateAcc b e =
  case snd b of
    Acc a _ _ _ ->
      let delay st =
            let accs = stateUpdateAccs st
                -- New entries go to the front of any existing list.
                accs' = IM.insertWith (++) (nameToId a) [(b, e)] accs
             in st {stateUpdateAccs = accs'}
       in modify delay
    _ ->
      compilerBugS
        "Type error: UpdateAcc did not produce accumulator typed value."
-- | Graph a function application. Calls to functions that 'isHostOnlyFun'
-- reports as host-only are graphed as host-only statements; all other calls
-- are graphed as simple statements.
graphApply :: Name -> [Binding] -> Exp GPU -> Grapher ()
graphApply fn bs e = do
  host_only <- isHostOnlyFun fn
  let grapher
        | host_only = graphHostOnly
        | otherwise = graphSimple bs
  grapher e
-- | Graph a match statement.
--
-- The default body and every case body are graphed inside 'incForkDepthFor',
-- and their statistics are captured. The statement may be migrated only if
-- no branch is host-only and 'reusesBranches' accepts the branch results.
-- When migration is possible, the graphed scalars among the condition
-- variables become operands of the statement. Otherwise every condition
-- variable is connected to the sink and no condition operands are reported.
--
-- Finally, each bound variable becomes a node whose operands are the
-- condition operands plus the graphed scalars that the branches return in
-- the corresponding result position.
graphMatch :: [Binding] -> [SubExp] -> [Case (Body GPU)] -> Body GPU -> Grapher ()
graphMatch bs ses cases defbody = do
  let bodies = defbody : map caseBody cases
      branch_results = map resultSubExps bodies
      cond_vars = subExpVars ses

  branch_stats <- incForkDepthFor $ mapM (captureBodyStats . graphBody) bodies
  may_copy_results <- reusesBranches bs branch_results

  let may_migrate = not (any bodyHostOnly branch_stats) && may_copy_results

  cond_ops <-
    if may_migrate
      then onlyGraphedScalars cond_vars
      else IS.empty <$ mapM_ (connectToSink . nameToId) cond_vars

  tellOperands cond_ops

  -- All operand sets are computed before any node is created.
  per_result_ops <- mapM (withCondition cond_ops) (L.transpose branch_results)
  zipWithM_ createNode bs per_result_ops
  where
    resultSubExps :: Body GPU -> [SubExp]
    resultSubExps body = map resSubExp (bodyResult body)

    -- Operands for one result position: the condition operands joined with
    -- the graphed scalars that the branches return at that position.
    withCondition :: Operands -> [SubExp] -> Grapher Operands
    withCondition cond_ops position_ses = do
      returned <- onlyGraphedScalars (S.fromList (subExpVars position_ses))
      pure (cond_ops <> returned)
-- | Ids of loop-value bindings that were reached while reducing over the
-- migration graph.
type ReachableBindings = IdSet

-- | Memoised results of 'MG.reduce', threaded through successive reductions.
type ReachableBindingsCache = MG.Visited (MG.Result ReachableBindings)

-- | Ids collected while searching a loop subgraph for reachable bindings.
-- NOTE(review): the exact invariant ("not yet exhausted") is inferred from
-- the name; confirm against 'findBindings'.
type NonExhausted = [Id]

-- | A scalar value carried through a loop.
type LoopValue =
  ( Binding, -- the binding produced by the loop statement
    Id, -- the id of the loop parameter
    SubExp, -- the initial argument of the parameter
    SubExp -- the value returned for it by the loop body
  )
graphLoop ::
[Binding] ->
[(FParam GPU, SubExp)] ->
LoopForm ->
Body GPU ->
Grapher ()
graphLoop :: [Binding]
-> [(FParam GPU, SubExp)] -> LoopForm -> Body GPU -> Grapher ()
graphLoop [] [(FParam GPU, SubExp)]
_ LoopForm
_ Body GPU
_ =
String -> Grapher ()
forall a. String -> a
compilerBugS String
"Loop statement bound no variable; should have been eliminated."
graphLoop (Binding
b : [Binding]
bs) [(FParam GPU, SubExp)]
params LoopForm
lform Body GPU
body = do
Graph
g <- Grapher Graph
getGraph
BodyStats
stats <- Grapher () -> Grapher BodyStats
forall a. Grapher a -> Grapher BodyStats
captureBodyStats (Id
subgraphId Id -> Grapher () -> Grapher ()
forall a. Id -> Grapher a -> Grapher a
`graphIdFor` Grapher ()
graphTheLoop)
let args :: [SubExp]
args = ((Param DeclType, SubExp) -> SubExp)
-> [(Param DeclType, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(Param DeclType, SubExp)]
[(FParam GPU, SubExp)]
params
let results :: [SubExp]
results = (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp (Body GPU -> [SubExpRes]
forall rep. Body rep -> [SubExpRes]
bodyResult Body GPU
body)
Bool
may_copy_results <- [Binding] -> [[SubExp]] -> Grapher Bool
reusesBranches (Binding
b Binding -> [Binding] -> [Binding]
forall a. a -> [a] -> [a]
: [Binding]
bs) [[SubExp]
args, [SubExp]
results]
let may_migrate :: Bool
may_migrate = Bool -> Bool
not (BodyStats -> Bool
bodyHostOnly BodyStats
stats) Bool -> Bool -> Bool
&& Bool
may_copy_results
Bool -> Grapher () -> Grapher ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless Bool
may_migrate (Grapher () -> Grapher ()) -> Grapher () -> Grapher ()
forall a b. (a -> b) -> a -> b
$ case LoopForm
lform of
ForLoop VName
_ IntType
_ (Var VName
n) -> Id -> Grapher ()
connectToSink (VName -> Id
nameToId VName
n)
WhileLoop VName
n
| Just (Binding
_, Id
p, SubExp
_, SubExp
res) <- VName -> Maybe (Binding, Id, SubExp, SubExp)
loopValueFor VName
n -> do
Id -> Grapher ()
connectToSink Id
p
case SubExp
res of
Var VName
v -> Id -> Grapher ()
connectToSink (VName -> Id
nameToId VName
v)
SubExp
_ -> () -> Grapher ()
forall a. a -> StateT State (Reader Env) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
LoopForm
_ -> () -> Grapher ()
forall a. a -> StateT State (Reader Env) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
((Binding, Id, SubExp, SubExp) -> Grapher ())
-> [(Binding, Id, SubExp, SubExp)] -> Grapher ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Binding, Id, SubExp, SubExp) -> Grapher ()
mergeLoopParam [(Binding, Id, SubExp, SubExp)]
loopValues
[Id]
srcs <- Id -> Grapher [Id]
routeSubgraph Id
subgraphId
[(Binding, Id, SubExp, SubExp)]
-> ((Binding, Id, SubExp, SubExp) -> Grapher ()) -> Grapher ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [(Binding, Id, SubExp, SubExp)]
loopValues (((Binding, Id, SubExp, SubExp) -> Grapher ()) -> Grapher ())
-> ((Binding, Id, SubExp, SubExp) -> Grapher ()) -> Grapher ()
forall a b. (a -> b) -> a -> b
$ \(Binding
bnd, Id
p, SubExp
_, SubExp
_) -> Binding -> Operands -> Grapher ()
createNode Binding
bnd (Id -> Operands
IS.singleton Id
p)
Graph
g' <- Grapher Graph
getGraph
let (Operands
dbs, ReachableBindingsCache
rbc) = ((Operands, ReachableBindingsCache)
-> Id -> (Operands, ReachableBindingsCache))
-> (Operands, ReachableBindingsCache)
-> [Id]
-> (Operands, ReachableBindingsCache)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' (Graph
-> (Operands, ReachableBindingsCache)
-> Id
-> (Operands, ReachableBindingsCache)
deviceBindings Graph
g') (Operands
IS.empty, ReachableBindingsCache
forall a. Visited a
MG.none) [Id]
srcs
(Sources -> Sources) -> Grapher ()
modifySources ((Sources -> Sources) -> Grapher ())
-> (Sources -> Sources) -> Grapher ()
forall a b. (a -> b) -> a -> b
$ ([Id] -> [Id]) -> Sources -> Sources
forall b c a. (b -> c) -> (a, b) -> (a, c)
forall (p :: * -> * -> *) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second (Operands -> [Id]
IS.toList Operands
dbs <>)
let ops :: Operands
ops = (Id -> Bool) -> Operands -> Operands
IS.filter (Id -> Graph -> Bool
forall m. Id -> Graph m -> Bool
`MG.member` Graph
g) (BodyStats -> Operands
bodyOperands BodyStats
stats)
(ReachableBindingsCache
-> Id -> StateT State (Reader Env) ReachableBindingsCache)
-> ReachableBindingsCache -> [Id] -> Grapher ()
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m ()
foldM_ ReachableBindingsCache
-> Id -> StateT State (Reader Env) ReachableBindingsCache
connectOperand ReachableBindingsCache
rbc (Operands -> [Id]
IS.elems Operands
ops)
Bool -> Grapher () -> Grapher ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
may_migrate (Grapher () -> Grapher ()) -> Grapher () -> Grapher ()
forall a b. (a -> b) -> a -> b
$ case LoopForm
lform of
ForLoop VName
_ IntType
_ SubExp
n ->
SubExp -> Grapher Operands
onlyGraphedScalarSubExp SubExp
n Grapher Operands -> (Operands -> Grapher ()) -> Grapher ()
forall a b.
StateT State (Reader Env) a
-> (a -> StateT State (Reader Env) b)
-> StateT State (Reader Env) b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Edges -> Operands -> Grapher ()
addEdges (Operands -> Maybe Operands -> Edges
ToNodes Operands
bindings Maybe Operands
forall a. Maybe a
Nothing)
WhileLoop VName
n
| Just (Binding
_, Id
_, SubExp
arg, SubExp
_) <- VName -> Maybe (Binding, Id, SubExp, SubExp)
loopValueFor VName
n ->
SubExp -> Grapher Operands
onlyGraphedScalarSubExp SubExp
arg Grapher Operands -> (Operands -> Grapher ()) -> Grapher ()
forall a b.
StateT State (Reader Env) a
-> (a -> StateT State (Reader Env) b)
-> StateT State (Reader Env) b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Edges -> Operands -> Grapher ()
addEdges (Operands -> Maybe Operands -> Edges
ToNodes Operands
bindings Maybe Operands
forall a. Maybe a
Nothing)
LoopForm
_ -> () -> Grapher ()
forall a. a -> StateT State (Reader Env) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
where
subgraphId :: Id
subgraphId :: Id
subgraphId = Binding -> Id
forall a b. (a, b) -> a
fst Binding
b
loopValues :: [LoopValue]
loopValues :: [(Binding, Id, SubExp, SubExp)]
loopValues =
let tmp :: [(Binding, (Param DeclType, SubExp), SubExpRes)]
tmp = [Binding]
-> [(Param DeclType, SubExp)]
-> [SubExpRes]
-> [(Binding, (Param DeclType, SubExp), SubExpRes)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 (Binding
b Binding -> [Binding] -> [Binding]
forall a. a -> [a] -> [a]
: [Binding]
bs) [(Param DeclType, SubExp)]
[(FParam GPU, SubExp)]
params (Body GPU -> [SubExpRes]
forall rep. Body rep -> [SubExpRes]
bodyResult Body GPU
body)
tmp' :: [(Binding, Id, SubExp, SubExp)]
tmp' = (((Binding, (Param DeclType, SubExp), SubExpRes)
-> (Binding, Id, SubExp, SubExp))
-> [(Binding, (Param DeclType, SubExp), SubExpRes)]
-> [(Binding, Id, SubExp, SubExp)])
-> [(Binding, (Param DeclType, SubExp), SubExpRes)]
-> ((Binding, (Param DeclType, SubExp), SubExpRes)
-> (Binding, Id, SubExp, SubExp))
-> [(Binding, Id, SubExp, SubExp)]
forall a b c. (a -> b -> c) -> b -> a -> c
flip ((Binding, (Param DeclType, SubExp), SubExpRes)
-> (Binding, Id, SubExp, SubExp))
-> [(Binding, (Param DeclType, SubExp), SubExpRes)]
-> [(Binding, Id, SubExp, SubExp)]
forall a b. (a -> b) -> [a] -> [b]
map [(Binding, (Param DeclType, SubExp), SubExpRes)]
tmp (((Binding, (Param DeclType, SubExp), SubExpRes)
-> (Binding, Id, SubExp, SubExp))
-> [(Binding, Id, SubExp, SubExp)])
-> ((Binding, (Param DeclType, SubExp), SubExpRes)
-> (Binding, Id, SubExp, SubExp))
-> [(Binding, Id, SubExp, SubExp)]
forall a b. (a -> b) -> a -> b
$
\(Binding
bnd, (Param DeclType
p, SubExp
arg), SubExpRes
res) ->
let i :: Id
i = VName -> Id
nameToId (Param DeclType -> VName
forall dec. Param dec -> VName
paramName Param DeclType
p)
in (Binding
bnd, Id
i, SubExp
arg, SubExpRes -> SubExp
resSubExp SubExpRes
res)
in ((Binding, Id, SubExp, SubExp) -> Bool)
-> [(Binding, Id, SubExp, SubExp)]
-> [(Binding, Id, SubExp, SubExp)]
forall a. (a -> Bool) -> [a] -> [a]
filter (\((Id
_, Type
t), Id
_, SubExp
_, SubExp
_) -> Type -> Bool
forall t. Typed t => t -> Bool
isScalar Type
t) [(Binding, Id, SubExp, SubExp)]
tmp'
bindings :: IdSet
bindings :: Operands
bindings = [Id] -> Operands
IS.fromList ([Id] -> Operands) -> [Id] -> Operands
forall a b. (a -> b) -> a -> b
$ ((Binding, Id, SubExp, SubExp) -> Id)
-> [(Binding, Id, SubExp, SubExp)] -> [Id]
forall a b. (a -> b) -> [a] -> [b]
map (\((Id
i, Type
_), Id
_, SubExp
_, SubExp
_) -> Id
i) [(Binding, Id, SubExp, SubExp)]
loopValues
loopValueFor :: VName -> Maybe (Binding, Id, SubExp, SubExp)
loopValueFor VName
n =
((Binding, Id, SubExp, SubExp) -> Bool)
-> [(Binding, Id, SubExp, SubExp)]
-> Maybe (Binding, Id, SubExp, SubExp)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find (\(Binding
_, Id
p, SubExp
_, SubExp
_) -> Id
p Id -> Id -> Bool
forall a. Eq a => a -> a -> Bool
== VName -> Id
nameToId VName
n) [(Binding, Id, SubExp, SubExp)]
loopValues
graphTheLoop :: Grapher ()
graphTheLoop :: Grapher ()
graphTheLoop = do
((Binding, Id, SubExp, SubExp) -> Grapher ())
-> [(Binding, Id, SubExp, SubExp)] -> Grapher ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Binding, Id, SubExp, SubExp) -> Grapher ()
forall {a} {d}. ((a, Type), Id, SubExp, d) -> Grapher ()
graphParam [(Binding, Id, SubExp, SubExp)]
loopValues
case LoopForm
lform of
ForLoop VName
_ IntType
_ SubExp
n ->
SubExp -> Grapher Operands
onlyGraphedScalarSubExp SubExp
n Grapher Operands -> (Operands -> Grapher ()) -> Grapher ()
forall a b.
StateT State (Reader Env) a
-> (a -> StateT State (Reader Env) b)
-> StateT State (Reader Env) b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Operands -> Grapher ()
tellOperands
WhileLoop VName
_ -> () -> Grapher ()
forall a. a -> StateT State (Reader Env) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
Body GPU -> Grapher ()
graphBody Body GPU
body
where
graphParam :: ((a, Type), Id, SubExp, d) -> Grapher ()
graphParam ((a
_, Type
t), Id
p, SubExp
arg, d
_) =
do
Binding -> Grapher ()
addVertex (Id
p, Type
t)
Operands
ops <- SubExp -> Grapher Operands
onlyGraphedScalarSubExp SubExp
arg
Edges -> Operands -> Grapher ()
addEdges (Id -> Edges
MG.oneEdge Id
p) Operands
ops
mergeLoopParam :: LoopValue -> Grapher ()
mergeLoopParam :: (Binding, Id, SubExp, SubExp) -> Grapher ()
mergeLoopParam (Binding
_, Id
p, SubExp
_, SubExp
res)
| Var VName
n <- SubExp
res,
Id
ret <- VName -> Id
nameToId VName
n,
Id
ret Id -> Id -> Bool
forall a. Eq a => a -> a -> Bool
/= Id
p =
Edges -> Operands -> Grapher ()
addEdges (Id -> Edges
MG.oneEdge Id
p) (Id -> Operands
IS.singleton Id
ret)
| Bool
otherwise =
() -> Grapher ()
forall a. a -> StateT State (Reader Env) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
deviceBindings ::
Graph ->
(ReachableBindings, ReachableBindingsCache) ->
Id ->
(ReachableBindings, ReachableBindingsCache)
deviceBindings :: Graph
-> (Operands, ReachableBindingsCache)
-> Id
-> (Operands, ReachableBindingsCache)
deviceBindings Graph
g (Operands
rb, ReachableBindingsCache
rbc) Id
i =
let (Result Operands
r, ReachableBindingsCache
rbc') = Graph
-> (Operands -> EdgeType -> Vertex Meta -> Operands)
-> ReachableBindingsCache
-> EdgeType
-> Id
-> (Result Operands, ReachableBindingsCache)
forall a m.
Monoid a =>
Graph m
-> (a -> EdgeType -> Vertex m -> a)
-> Visited (Result a)
-> EdgeType
-> Id
-> (Result a, Visited (Result a))
MG.reduce Graph
g Operands -> EdgeType -> Vertex Meta -> Operands
bindingReach ReachableBindingsCache
rbc EdgeType
Normal Id
i
in case Result Operands
r of
Produced Operands
rb' -> (Operands
rb Operands -> Operands -> Operands
forall a. Semigroup a => a -> a -> a
<> Operands
rb', ReachableBindingsCache
rbc')
Result Operands
_ ->
String -> (Operands, ReachableBindingsCache)
forall a. String -> a
compilerBugS
String
"Migration graph sink could be reached from source after it\
\ had been attempted routed."
bindingReach ::
ReachableBindings ->
EdgeType ->
Vertex Meta ->
ReachableBindings
bindingReach :: Operands -> EdgeType -> Vertex Meta -> Operands
bindingReach Operands
rb EdgeType
_ Vertex Meta
v
| Id
i <- Vertex Meta -> Id
forall m. Vertex m -> Id
vertexId Vertex Meta
v,
Id -> Operands -> Bool
IS.member Id
i Operands
bindings =
Id -> Operands -> Operands
IS.insert Id
i Operands
rb
| Bool
otherwise =
Operands
rb
-- | Re-routes the outgoing edges of an operand's vertex so that they point at
-- the bindings reachable from its edge targets inside the subgraph identified
-- by @subgraphId@ (reachability is accumulated with 'bindingReach').
--
-- * Operands without a vertex, or whose edges already go to the sink, are left
--   untouched and the cache is returned unchanged.
-- * When a 'ToNodes' vertex carries a second (@Just@) edge set, only that set
--   is searched; otherwise the full edge set is searched.
-- * If any search reaches the sink, the operand is connected to the sink.
-- * Otherwise the found bindings are added to the vertex's edges, and the
--   second edge set is replaced by the found bindings plus the targets that
--   lie outside the subgraph ("non-exhausted" targets).
--
-- The 'ReachableBindingsCache' is threaded through 'MG.reduce' and the
-- updated cache is returned.
connectOperand ::
  ReachableBindingsCache ->
  Id ->
  Grapher ReachableBindingsCache
connectOperand cache op = do
  g <- getGraph
  case MG.lookup op g of
    Nothing -> pure cache
    Just v ->
      case vertexEdges v of
        ToSink -> pure cache
        ToNodes es Nothing -> connectOp g cache op es
        ToNodes _ (Just nx) -> connectOp g cache op nx
  where
    -- Search the given edge targets of operand @i@ and update its vertex
    -- according to the outcome.
    connectOp ::
      Graph ->
      ReachableBindingsCache ->
      Id ->
      IdSet ->
      Grapher ReachableBindingsCache
    connectOp g rbc i es = do
      let (res, nx, rbc') = findBindings g (IS.empty, [], rbc) (IS.elems es)
      case res of
        FoundSink -> connectToSink i
        Produced rb -> modifyGraph $ MG.adjust (updateEdges nx rb) i
      pure rbc'

    -- Add the reached bindings to the vertex's edges and record them, along
    -- with the non-exhausted targets, as the set to search next time.
    -- Vertices connected to the sink are returned as-is.
    updateEdges ::
      NonExhausted ->
      ReachableBindings ->
      Vertex Meta ->
      Vertex Meta
    updateEdges nx rb v
      | ToNodes es _ <- vertexEdges v =
          let nx' = IS.fromList nx
              es' = ToNodes (rb <> es) $ Just (rb <> nx')
           in v {vertexEdges = es'}
      | otherwise = v

    -- Walk the edge targets: targets inside the subgraph are reduced with
    -- 'bindingReach' (stopping early on the sink); all other targets are
    -- kept as non-exhausted.
    findBindings ::
      Graph ->
      (ReachableBindings, NonExhausted, ReachableBindingsCache) ->
      [Id] ->
      (MG.Result ReachableBindings, NonExhausted, ReachableBindingsCache)
    findBindings _ (rb, nx, rbc) [] =
      (Produced rb, nx, rbc)
    findBindings g (rb, nx, rbc) (i : is)
      | Just v <- MG.lookup i g,
        Just gid <- metaGraphId (vertexMeta v),
        gid == subgraphId =
          let (res, rbc') = MG.reduce g bindingReach rbc Normal i
           in case res of
                FoundSink -> (FoundSink, [], rbc')
                Produced rb' -> findBindings g (rb <> rb', nx, rbc') is
      | otherwise =
          findBindings g (rb, i : nx, rbc) is
-- | Graphs a 'WithAcc' statement whose results are bound to the given
-- bindings.
--
-- 1. The lambda body is graphed first.
-- 2. Each accumulator returned by the lambda is paired with its input. For
--    each one, the delayed statements recorded under the accumulator's id in
--    'stateUpdateAccs' are removed from the state and handed to 'graphAcc'.
--    The neutral elements of the combining operator (if any) are connected to
--    the sink.
-- 3. Finally the bindings are related to the returned values. The leading
--    bindings correspond one-to-one to the input arrays, so only the remaining
--    bindings are paired with the lambda's non-accumulator results.
graphWithAcc ::
  [Binding] ->
  [WithAccInput GPU] ->
  Lambda GPU ->
  Grapher ()
graphWithAcc bs inputs f = do
  graphBody (lambdaBody f)

  mapM_ graph $ zip (lambdaReturnType f) inputs

  -- Input arrays come first in the bound results, followed by the
  -- non-accumulator lambda results.
  let arrs = concatMap (\(_, as, _) -> map Var as) inputs
  let res = drop (length inputs) (bodyResult $ lambdaBody f)
  void $ reusesReturn bs (arrs ++ map resSubExp res)

  ret <- mapM (onlyGraphedScalarSubExp . resSubExp) res
  mapM_ (uncurry createNode) $ zip (drop (length arrs) bs) ret
  where
    graph :: (Type, WithAccInput GPU) -> Grapher ()
    graph (Acc a _ types _, (_, _, comb)) = do
      let i = nameToId a

      -- Take ownership of the statements delayed for this accumulator.
      delayed <- fromMaybe [] <$> gets (IM.lookup i . stateUpdateAccs)
      modify $ \st -> st {stateUpdateAccs = IM.delete i (stateUpdateAccs st)}

      graphAcc i types (fst <$> comb) delayed

      mapM_ connectSubExpToSink $ maybe [] snd comb
    graph _ =
      compilerBugS "Type error: WithAcc expression did not return accumulator."
-- | Graphs the combining operator of an accumulator together with the delayed
-- statements recorded for it.
--
-- Arguments:
--
-- * the 'Id' of the accumulator token;
-- * the accumulator's element types;
-- * its combining lambda, if any;
-- * the delayed @(binding, expression)@ pairs associated with the
--   accumulator.
--
-- With no delayed statements the token is simply added as a source.
-- Otherwise the operator body is graphed in a throwaway copy of the current
-- state to gather 'BodyStats', and one of three strategies is chosen:
--
--   * The operator is host-only or contains a 'GPUBody': every delayed
--     expression is graphed with 'graphHostOnly', and the operator's scalar
--     operands are connected to the sink.
--
--   * The operator reads, or an element type is scalar: every delayed
--     binding is graphed with 'graphAutoMove', and the token becomes a source.
--
--   * Otherwise: the token gets a node with the operator's scalar operands,
--     and each delayed binding gets a node with its own scalar operands plus
--     the token (both via 'createNode').
graphAcc :: Id -> [Type] -> Maybe (Lambda GPU) -> [Delayed] -> Grapher ()
graphAcc i _ _ [] = addSource (i, Prim Unit)
graphAcc i types op delayed = do
  env <- ask
  st <- get

  -- Graph the operator body against a copy of the current state purely to
  -- gather statistics; the graph changes made there are discarded.
  let lambda = fromMaybe (Lambda [] [] (Body () SQ.empty [])) op
  let m = graphBody (lambdaBody lambda)
  let stats = R.runReader (evalStateT (captureBodyStats m) st) env
  let host_only = bodyHostOnly stats || bodyHasGPUBody stats
  let does_read = bodyReads stats || any isScalar types

  -- Wrapping the operator in an input-less 'WithAcc' makes
  -- 'graphedScalarOperands' collect just the operands of its body.
  ops <- graphedScalarOperands (WithAcc [] lambda)

  case (host_only, does_read) of
    (True, _) -> do
      mapM_ (graphHostOnly . snd) delayed
      addEdges ToSink ops
    (_, True) -> do
      mapM_ (graphAutoMove . fst) delayed
      addSource (i, Prim Unit)
    _ -> do
      createNode (i, Prim Unit) ops
      forM_ delayed $ \(b, e) ->
        graphedScalarOperands e >>= createNode b . IS.insert i
-- | Returns the scalar operands an expression needs in order to be evaluated
-- on host, restricted to scalars that have already been graphed (the
-- collected ids are intersected with 'getGraphedScalars').
--
-- Operands are gathered by a local 'execState' pass. Its state is a pair of:
--
-- * the collected operand ids;
-- * the names of accumulators updated by an 'UpdateAcc' statement.
--
-- Only the first component is returned. The second decides, in
-- 'collectWithAcc', whether an accumulator's combining operator body is
-- collected.
--
-- Kernel bodies of 'SegOp's are not traversed. Only their level, space
-- dimensions, neutral elements and histogram race factors are collected.
-- 'GPUBody' contributes nothing.
graphedScalarOperands :: Exp GPU -> Grapher Operands
graphedScalarOperands e =
  let is = fst $ execState (collect e) initial
   in IS.intersection is <$> getGraphedScalars
  where
    -- (scalar operand ids, accumulator tokens used by UpdateAcc)
    initial = (IS.empty, S.empty)

    captureName n = modify $ first $ IS.insert (nameToId n)
    captureAcc a = modify $ second $ S.insert a
    collectFree x = mapM_ captureName (namesToList $ freeIn x)

    collect b@BasicOp {} =
      collectBasic b
    collect (Apply _ params _ _) =
      mapM_ (collectSubExp . fst) params
    collect (Match ses cases defbody _) = do
      mapM_ collectSubExp ses
      mapM_ (collectBody . caseBody) cases
      collectBody defbody
    collect (Loop params lform body) = do
      mapM_ (collectSubExp . snd) params
      collectLForm lform
      collectBody body
    collect (WithAcc accs f) =
      collectWithAcc accs f
    collect (Op op) =
      collectHostOp op

    -- For 'Update' only the slice is collected and for 'Replicate' only the
    -- shape; the written / replicated value is deliberately not collected.
    collectBasic (BasicOp (Update _ _ slice _)) =
      collectFree slice
    collectBasic (BasicOp (Replicate shape _)) =
      collectFree shape
    collectBasic e' =
      walkExpM (identityWalker {walkOnSubExp = collectSubExp}) e'

    collectSubExp (Var n) = captureName n
    collectSubExp _ = pure ()

    collectBody body = do
      collectStms (bodyStms body)
      collectFree (bodyResult body)
    collectStms = mapM_ collectStm

    -- A single-result 'UpdateAcc' also records the accumulator it updates.
    collectStm (Let pat _ ua)
      | BasicOp UpdateAcc {} <- ua,
        Pat [pe] <- pat,
        Acc a _ _ _ <- typeOf pe =
          captureAcc a >> collectBasic ua
    collectStm stm = collect (stmExp stm)

    collectLForm (ForLoop _ _ b) = collectSubExp b
    collectLForm (WhileLoop _) = pure ()

    -- The lambda body is collected first so that 'used_accs' reflects the
    -- accumulators it updates; an operator body is only collected when its
    -- accumulator was updated.
    collectWithAcc inputs f = do
      collectBody (lambdaBody f)
      used_accs <- gets snd
      let accs = take (length inputs) (lambdaReturnType f)
      let used = map (accIsUsed used_accs) accs
      mapM_ collectAcc (zip used inputs)

    -- Total replacement for a lambda pattern that previously failed with a
    -- bare pattern-match error on a non-accumulator type.
    accIsUsed used_accs (Acc a _ _ _) = S.member a used_accs
    accIsUsed _ _ =
      compilerBugS "Type error: WithAcc expression did not return accumulator."

    collectAcc (_, (_, _, Nothing)) = pure ()
    collectAcc (used, (_, _, Just (op, nes))) = do
      mapM_ collectSubExp nes
      when used $ collectBody (lambdaBody op)

    collectHostOp (SegOp (SegMap lvl sp _ _)) = do
      collectSegLevel lvl
      collectSegSpace sp
    collectHostOp (SegOp (SegRed lvl sp _ _ ops)) = do
      collectSegLevel lvl
      collectSegSpace sp
      mapM_ collectSegBinOp ops
    collectHostOp (SegOp (SegScan lvl sp _ _ ops)) = do
      collectSegLevel lvl
      collectSegSpace sp
      mapM_ collectSegBinOp ops
    collectHostOp (SegOp (SegHist lvl sp _ _ ops)) = do
      collectSegLevel lvl
      collectSegSpace sp
      mapM_ collectHistOp ops
    collectHostOp (SizeOp op) = collectFree op
    collectHostOp (OtherOp op) = collectFree op
    collectHostOp GPUBody {} = pure ()

    collectSegLevel = mapM_ captureName . namesToList . freeIn
    collectSegSpace space =
      mapM_ collectSubExp (segSpaceDims space)
    collectSegBinOp (SegBinOp _ _ nes _) =
      mapM_ collectSubExp nes
    collectHistOp (HistOp _ rf _ nes _ _) = do
      collectSubExp rf
      mapM_ collectSubExp nes
-- | Adds a vertex for the binding and passes @MG.oneEdge@ of its id to
-- 'addEdges' for every operand in @ops@. When @ops@ is empty nothing is
-- added at all.
createNode :: Binding -> Operands -> Grapher ()
createNode b ops =
  unless (IS.null ops) $ do
    addVertex b
    addEdges (MG.oneEdge $ fst b) ops
-- | Inserts a vertex for the binding, tagged with the current 'Meta'.
--
-- Scalar bindings are also added to the set of graphed scalars. Array
-- bindings are passed to 'recordCopyableMemory' together with the current
-- body depth.
addVertex :: Binding -> Grapher ()
addVertex (i, t) = do
  meta <- getMeta
  let v = MG.vertex i meta
  when (isScalar t) $ modifyGraphedScalars (IS.insert i)
  when (isArray t) $ recordCopyableMemory i (metaBodyDepth meta)
  modifyGraph (MG.insert v)
-- | Adds a vertex for the binding (see 'addVertex') and registers its id as
-- an unrouted source by prepending it to the second component of 'Sources'.
addSource :: Binding -> Grapher ()
addSource b =
  addVertex b >> modifySources (second (fst b :))
-- | Adds edges for every operand id in @is@.
--
-- * 'ToSink': each operand is connected to the sink and is no longer tracked
--   as a graphed scalar.
-- * Any other 'Edges' value @es@: @MG.addEdges es@ is applied to each
--   operand, and the operands are recorded in the current body's
--   statistics via 'tellOperands'.
--
-- Operands are visited in ascending id order ('IS.foldl'').
addEdges :: Edges -> IdSet -> Grapher ()
addEdges ToSink is = do
  modifyGraph $ \g0 -> IS.foldl' (\g i -> MG.connectToSink i g) g0 is
  modifyGraphedScalars $ \gss -> gss `IS.difference` is
addEdges es is = do
  modifyGraph $ \g0 -> IS.foldl' (\g i -> MG.addEdges es i g) g0 is
  tellOperands is
-- | Marks vertex @i@ as required on host.
--
-- If the vertex exists, it is connected to the sink, and the body depth at
-- which it was created (taken from its 'Meta') is recorded as a host-only
-- parent. Ids without a vertex are ignored.
requiredOnHost :: Id -> Grapher ()
requiredOnHost i = do
  g <- getGraph
  forM_ (MG.lookup i g) $ \v -> do
    connectToSink i
    tellHostOnlyParent $ metaBodyDepth (vertexMeta v)
-- | Connects vertex @i@ to the sink and stops tracking it as a graphed
-- scalar.
connectToSink :: Id -> Grapher ()
connectToSink i =
  modifyGraph (MG.connectToSink i) >> modifyGraphedScalars (IS.delete i)
-- | 'connectToSink' for the variable in a subexpression; constants are
-- ignored.
connectSubExpToSink :: SubExp -> Grapher ()
connectSubExpToSink se = forM_ (subExpVar se) (connectToSink . nameToId)
-- | Routes the unrouted sources that belong to subgraph @si@ and returns
-- them.
--
-- Only the leading run of unrouted sources that lie in the subgraph is
-- taken ('span'). 'addSource' prepends new sources, so this run is a prefix
-- of the most recently added ones. Those sources are routed with
-- 'MG.routeMany' and moved to the front of the routed list. The sinks that
-- routing produced are prepended to the recorded sinks.
routeSubgraph :: Id -> Grapher [Id]
routeSubgraph si = do
  st <- get
  let graph = stateGraph st
      (routed, unrouted) = stateSources st
      (subSrcs, remaining) = span (inSubGraph si graph) unrouted
      (newSinks, graph') = MG.routeMany subSrcs graph
  put $
    st
      { stateGraph = graph',
        stateSources = (subSrcs ++ routed, remaining),
        stateSinks = newSinks ++ stateSinks st
      }
  pure subSrcs
-- | Whether vertex @i@ exists in @g@ and its 'metaGraphId' is @si@.
-- Vertices that are missing, or that carry no graph id, are not in any
-- subgraph.
inSubGraph :: Id -> Graph -> Id -> Bool
inSubGraph si g i =
  Just si == (MG.lookup i g >>= metaGraphId . vertexMeta)
-- | For an array binding @(i, t)@ whose value comes from @n@, records @i@ as
-- copyable memory at the body depth recorded for @n@, if one is recorded.
-- Non-array bindings, and names without a recorded depth, are ignored.
reuses :: Binding -> VName -> Grapher ()
reuses (i, t) n =
  when (isArray t) $
    outermostCopyableArray n >>= mapM_ (recordCopyableMemory i)
-- | 'reuses' for a subexpression; constants are ignored.
reusesSubExp :: Binding -> SubExp -> Grapher ()
reusesSubExp b = mapM_ (reuses b) . subExpVar
-- | Processes the array results @res@ of a body that are bound to @bs@.
--
-- This is exactly 'reusesBranches' with a single branch. @L.transpose
-- [res]@ yields one singleton column per result, and for a singleton column
-- 'reusesBranches' performs the same per-binding checks that a dedicated
-- implementation would. Delegating keeps the two from drifting apart.
--
-- For each array binding whose dimensions are not all the constant @1@ and
-- whose result is a variable:
--
-- * if that variable is recorded as copyable memory at depth @inner@, the
--   binding is recorded as copyable at @min bodyDepth inner@;
-- * if it is not recorded, or if @inner <= bodyDepth@ (the variable was not
--   bound deeper than the current body), the overall result becomes 'False'.
--
-- Returns 'True' iff no binding caused a 'False'. All pairs are processed
-- either way.
reusesReturn :: [Binding] -> [SubExp] -> Grapher Bool
reusesReturn bs res = reusesBranches bs [res]
-- | Processes the array results of a multi-branch construct.
--
-- @bs@ are the result bindings. @seses@ holds, for each branch, that
-- branch's result subexpressions. Results are examined column-wise: one
-- column per binding, one entry per branch.
--
-- For each binding:
--
-- * if every dimension of its type is the constant @1@ (this includes
--   scalars, which have no dimensions), it is skipped;
-- * if it is an array and every branch returns a variable:
--
--     * if every such variable is recorded as copyable memory, the binding is
--       recorded as copyable at the minimum of the current body depth and the
--       outermost recorded depth; it disqualifies the overall result when
--       that outermost depth is not deeper than the current body depth (the
--       code treats this as returning a free variable);
--     * otherwise it disqualifies the overall result;
--
-- * otherwise it is skipped.
--
-- Returns 'True' iff no binding was disqualified. Every binding is
-- processed, so copyable memory is still recorded after a disqualification.
reusesBranches :: [Binding] -> [[SubExp]] -> Grapher Bool
reusesBranches bs seses = do
  depth <- getBodyDepth
  foldM (step depth) True (zip bs (L.transpose seses))
  where
    step :: BodyDepth -> Bool -> (Binding, [SubExp]) -> Grapher Bool
    step depth acc ((i, t), ses)
      | all (== intConst Int64 1) (arrayDims t) =
          -- Single element arrays are immediately copyable.
          pure acc
      | isArray t,
        Just names <- mapM subExpVar ses = do
          recorded <- mapM outermostCopyableArray names
          case sequence recorded of
            Nothing -> pure False
            Just depths -> do
              -- 'depths' is non-empty: columns produced by 'L.transpose'
              -- are never empty, so 'minimum' is safe here.
              let inner = minimum depths
              recordCopyableMemory i (min depth inner)
              pure (acc && inner > depth)
      | otherwise = pure acc
-- | The monad in which the migration graph is built: mutable 'State' layered
-- over a read-only 'Env'.
type Grapher = StateT State (R.Reader Env)
-- | Read-only environment of the 'Grapher' monad.
data Env = Env
  { -- | Names of functions that are host-only (queried by 'isHostOnlyFun').
    envHostOnlyFuns :: HostOnlyFuns,
    -- | Metadata attached to vertices created in the current scope.
    envMeta :: Meta
  }
-- | Body nesting depth: 0 at the top level, incremented by 'incBodyDepthFor'.
type BodyDepth = Int
-- | Metadata attached to every graph vertex (see 'addVertex').
data Meta = Meta
  { -- | Fork depth; incremented by 'incForkDepthFor'.
    metaForkDepth :: Int,
    -- | Body depth; incremented by 'incBodyDepthFor'.
    metaBodyDepth :: BodyDepth,
    -- | The subgraph the vertex belongs to, if any; set by 'graphIdFor'.
    metaGraphId :: Maybe Id
  }
-- | A set of vertex ids that are used as operands.
type Operands = IdSet
-- | Statistics gathered while graphing a body (see 'captureBodyStats').
data BodyStats = BodyStats
  { -- | Set by 'tellHostOnly'.
    bodyHostOnly :: Bool,
    -- | Set by 'tellGPUBody'.
    bodyHasGPUBody :: Bool,
    -- | Set by 'tellRead'.
    bodyReads :: Bool,
    -- | Operands recorded via 'tellOperands'.
    bodyOperands :: Operands,
    -- | Body depths recorded via 'tellHostOnlyParent'.
    bodyHostOnlyParents :: IS.IntSet
  }
-- | Combines statistics field by field: flags with '||', sets by union.
instance Semigroup BodyStats where
  BodyStats hoA gbA rdA opsA parA <> BodyStats hoB gbB rdB opsB parB =
    BodyStats
      { bodyHostOnly = hoA || hoB,
        bodyHasGPUBody = gbA || gbB,
        bodyReads = rdA || rdB,
        bodyOperands = IS.union opsA opsB,
        bodyHostOnlyParents = IS.union parA parB
      }
-- | No statistics: every flag 'False', every set empty.
instance Monoid BodyStats where
  mempty = BodyStats False False False IS.empty IS.empty
-- | The migration graph, with 'Meta' attached to each vertex.
type Graph = MG.Graph Meta

-- | Source vertex ids as a @(routed, unrouted)@ pair.
type Sources = ([Id], [Id])

-- | Sink vertex ids produced by routing (see 'routeSubgraph').
type Sinks = [Id]

-- | A binding paired with a GPU expression.
type Delayed = (Binding, Exp GPU)

-- | A vertex id together with the type of the bound value.
type Binding = (Id, Type)

-- | Maps array vertex ids to the body depth recorded for their memory.
type CopyableMemoryMap = IM.IntMap BodyDepth
-- | Mutable state of the 'Grapher' monad.
data State = State
  { -- | The graph built so far.
    stateGraph :: Graph,
    -- | Scalar vertices currently tracked (added by 'addVertex', removed by
    -- 'connectToSink').
    stateGraphedScalars :: IdSet,
    -- | Routed and unrouted sources.
    stateSources :: Sources,
    -- | Sinks produced by routing.
    stateSinks :: Sinks,
    -- | Delayed bindings keyed by an id.
    -- NOTE(review): presumably the id of an updated accumulator — confirm.
    stateUpdateAccs :: IM.IntMap [Delayed],
    -- | Body depths recorded for copyable array memory.
    stateCopyableMemory :: CopyableMemoryMap,
    -- | Statistics of the body currently being graphed.
    stateStats :: BodyStats
  }
-- | Runs a graph builder from an empty state.
--
-- The builder runs with the given host-only functions and top-level
-- metadata: fork and body depth 0, and no subgraph. Returns the final graph,
-- sources and sinks; the builder's own result is discarded.
execGrapher :: HostOnlyFuns -> Grapher a -> (Graph, Sources, Sinks)
execGrapher hof m = (stateGraph final, stateSources final, stateSinks final)
  where
    final = R.runReader (execStateT m initial) env

    env =
      Env
        { envHostOnlyFuns = hof,
          envMeta =
            Meta
              { metaForkDepth = 0,
                metaBodyDepth = 0,
                metaGraphId = Nothing
              }
        }

    initial =
      State
        { stateGraph = MG.empty,
          stateGraphedScalars = IS.empty,
          stateSources = ([], []),
          stateSinks = [],
          stateUpdateAccs = IM.empty,
          stateCopyableMemory = IM.empty,
          stateStats = mempty
        }
-- | Runs a 'Grapher' computation under a modified environment. State
-- changes made by the computation are kept.
local :: (Env -> Env) -> Grapher a -> Grapher a
local = mapStateT . R.local
-- | Returns the whole environment.
ask :: Grapher Env
ask = asks id
-- | Returns a projection of the environment.
asks :: (Env -> a) -> Grapher a
asks f = lift (R.asks f)
-- | Marks the current body as host-only in its statistics.
tellHostOnly :: Grapher ()
tellHostOnly = modify setFlag
  where
    setFlag st = st {stateStats = (stateStats st) {bodyHostOnly = True}}
-- | Records in the current body's statistics that it contains a GPUBody.
tellGPUBody :: Grapher ()
tellGPUBody = modify setFlag
  where
    setFlag st = st {stateStats = (stateStats st) {bodyHasGPUBody = True}}
-- | Records in the current body's statistics that it performs a read.
tellRead :: Grapher ()
tellRead = modify setFlag
  where
    setFlag st = st {stateStats = (stateStats st) {bodyReads = True}}
-- | Adds the given ids to the operands recorded for the current body.
tellOperands :: IdSet -> Grapher ()
tellOperands is = modify addOps
  where
    addOps st =
      let stats = stateStats st
       in st {stateStats = stats {bodyOperands = IS.union (bodyOperands stats) is}}
-- | Adds a body depth to the host-only parents recorded for the current body.
tellHostOnlyParent :: BodyDepth -> Grapher ()
tellHostOnlyParent body_depth = modify addParent
  where
    addParent st =
      let stats = stateStats st
          parents = IS.insert body_depth (bodyHostOnlyParents stats)
       in st {stateStats = stats {bodyHostOnlyParents = parents}}
-- | Returns the graph built so far.
getGraph :: Grapher Graph
getGraph = stateGraph <$> get
-- | Returns the set of currently tracked scalar vertices.
getGraphedScalars :: Grapher IdSet
getGraphedScalars = stateGraphedScalars <$> get
-- | Returns the body depths recorded for copyable array memory.
getCopyableMemory :: Grapher CopyableMemoryMap
getCopyableMemory = stateCopyableMemory <$> get
-- | Returns the body depth recorded as copyable memory for @n@, or 'Nothing'
-- if no depth is recorded.
outermostCopyableArray :: VName -> Grapher (Maybe BodyDepth)
outermostCopyableArray n =
  gets (IM.lookup (nameToId n) . stateCopyableMemory)
-- | Returns the ids of those names that are currently tracked as graphed
-- scalars.
onlyGraphedScalars :: (Foldable t) => t VName -> Grapher IdSet
onlyGraphedScalars vs = do
  gss <- getGraphedScalars
  let ids = IS.fromList (map nameToId (toList vs))
  pure (ids `IS.intersection` gss)
-- | A singleton set with the id of @n@ if it is a graphed scalar, otherwise
-- the empty set.
onlyGraphedScalar :: VName -> Grapher IdSet
onlyGraphedScalar n = onlyGraphedScalars [n]
-- | 'onlyGraphedScalar' for the variable in a subexpression; constants give
-- the empty set.
onlyGraphedScalarSubExp :: SubExp -> Grapher IdSet
onlyGraphedScalarSubExp = maybe (pure IS.empty) onlyGraphedScalar . subExpVar
-- | Applies a function to the graph in the state.
modifyGraph :: (Graph -> Graph) -> Grapher ()
modifyGraph f = modify update
  where
    update st = st {stateGraph = f (stateGraph st)}
-- | Applies a function to the set of graphed scalars in the state.
modifyGraphedScalars :: (IdSet -> IdSet) -> Grapher ()
modifyGraphedScalars f = modify update
  where
    update st = st {stateGraphedScalars = f (stateGraphedScalars st)}
-- | Applies a function to the copyable-memory map in the state.
modifyCopyableMemory :: (CopyableMemoryMap -> CopyableMemoryMap) -> Grapher ()
modifyCopyableMemory f = modify update
  where
    update st = st {stateCopyableMemory = f (stateCopyableMemory st)}
-- | Applies a function to the @(routed, unrouted)@ sources in the state.
modifySources :: (Sources -> Sources) -> Grapher ()
modifySources f = modify update
  where
    update st = st {stateSources = f (stateSources st)}
-- | Records vertex @i@ as copyable memory at body depth @bd@, replacing any
-- depth previously recorded for it.
recordCopyableMemory :: Id -> BodyDepth -> Grapher ()
recordCopyableMemory i bd = modifyCopyableMemory $ IM.insert i bd
-- | Runs a computation with the fork depth in the environment's 'Meta'
-- increased by one.
incForkDepthFor :: Grapher a -> Grapher a
incForkDepthFor = local (onMeta bump)
  where
    bump meta = meta {metaForkDepth = metaForkDepth meta + 1}
    onMeta f env = env {envMeta = f (envMeta env)}
-- | Runs a computation with the body depth in the environment's 'Meta'
-- increased by one.
incBodyDepthFor :: Grapher a -> Grapher a
incBodyDepthFor = local (onMeta bump)
  where
    bump meta = meta {metaBodyDepth = metaBodyDepth meta + 1}
    onMeta f env = env {envMeta = f (envMeta env)}
-- | Runs a computation with subgraph id @i@ set in the environment's 'Meta'.
-- Vertices added within it are therefore tagged with @i@ (see 'addVertex').
graphIdFor :: Id -> Grapher a -> Grapher a
graphIdFor i = local setId
  where
    setId env = env {envMeta = (envMeta env) {metaGraphId = Just i}}
-- | Runs @m@ starting from empty body statistics and returns the statistics
-- that @m@ produced.
--
-- Afterwards, the captured statistics are merged into those that were in
-- effect before, so the enclosing body observes them too. The result of @m@
-- is discarded.
captureBodyStats :: Grapher a -> Grapher BodyStats
captureBodyStats m = do
  outer <- gets stateStats
  modify $ \st -> st {stateStats = mempty}
  void m
  inner <- gets stateStats
  modify $ \st -> st {stateStats = outer <> inner}
  pure inner
-- | Whether the named function is among the environment's host-only
-- functions.
isHostOnlyFun :: Name -> Grapher Bool
isHostOnlyFun fn = S.member fn <$> asks envHostOnlyFuns
-- | Returns the metadata of the current scope.
getMeta :: Grapher Meta
getMeta = envMeta <$> ask
-- | Returns the body depth of the current scope.
getBodyDepth :: Grapher BodyDepth
getBodyDepth = metaBodyDepth <$> getMeta