{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Simplify.Rules
( standardRules,
removeUnnecessaryCopy,
)
where
import Control.Monad
import Control.Monad.State
import Data.List (insert, unzip4, zip4)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.BasicOp
import Futhark.Optimise.Simplify.Rules.Index
import Futhark.Optimise.Simplify.Rules.Loop
import Futhark.Optimise.Simplify.Rules.Match
import Futhark.Util
-- | The generic top-down rules of this module, tried in order:
-- constant folding of primitive function applications,
-- 'withAccTopDown', and rewriting empty-array bindings to 'Scratch'.
topDownRules :: (BuilderOps rep) => [TopDownRule rep]
topDownRules =
  [ RuleGeneric constantFoldPrimFun,
    RuleGeneric withAccTopDown,
    RuleGeneric emptyArrayToScratch
  ]
-- | The bottom-up rules of this module: the generic 'withAccBottomUp'
-- rule and the 'BasicOp' rule 'simplifyIndex'.
bottomUpRules :: (BuilderOps rep, TraverseOpStms rep) => [BottomUpRule rep]
bottomUpRules =
  [ RuleGeneric withAccBottomUp,
    RuleBasicOp simplifyIndex
  ]
-- | The standard simplification rule book.
--
-- It is built from this module's 'topDownRules' and 'bottomUpRules',
-- and is combined (via the 'RuleBook' semigroup) with 'loopRules',
-- 'basicOpRules' and 'matchRules'.
standardRules :: (BuilderOps rep, TraverseOpStms rep) => RuleBook rep
standardRules =
  ruleBook topDownRules bottomUpRules
    <> loopRules
    <> basicOpRules
    <> matchRules
-- | Turn a copy of a variable, written as a 'Replicate' with an empty
-- shape, into a plain reference to that variable.
--
-- The rewrite fires only when all of the following hold:
--
-- * the usage table does not record the source @v@ as consumed;
--
-- * the bound name is not recorded as part of a result, or it is
--   itself consumed, or @v@ has no aliases and is not used again;
--
-- * either @v@ is not used again and may be consumed here (see
--   @consumable@ below), or the bound name is never consumed.
--
-- Any other statement is left alone ('Skip').
removeUnnecessaryCopy :: (BuilderOps rep) => BottomUpRuleBasicOp rep
removeUnnecessaryCopy (vtable, used) (Pat [d]) aux (Replicate (Shape []) (Var v))
  | not (v `UT.isConsumed` used),
    not (patElemName d `UT.isInResult` used)
      || (patElemName d `UT.isConsumed` used)
      || (v_is_fresh && v_not_used_again),
    (v_not_used_again && consumable) || not (patElemName d `UT.isConsumed` used) =
      Simplify $ auxing aux $ letBindNames [patElemName d] $ BasicOp $ SubExp $ Var v
  where
    -- The usage table has no use of @v@ at all.
    v_not_used_again = not (v `UT.used` used)
    -- The symbol table records no aliases for @v@.
    v_is_fresh = v `ST.lookupAliases` vtable == mempty
    -- @v@ may be consumed here.  It must be bound at the current loop
    -- depth, and it must be either a fresh statement-bound variable or a
    -- function parameter with a unique declared type.
    consumable = fromMaybe False $ do
      e <- ST.lookup v vtable
      guard $ ST.entryDepth e == ST.loopDepth vtable
      -- 'consumableStm' yields 'Nothing' unless it succeeds, so 'mplus'
      -- falls back to the parameter check, which always yields 'Just'.
      consumableStm e `mplus` consumableFParam e
    -- 'Just True' for a function parameter whose declared type is
    -- unique; 'Just False' for any other entry.
    consumableFParam =
      Just . maybe False (unique . declTypeOf) . ST.entryFParam
    -- 'Just True' if the entry is bound by a statement and @v@ is fresh.
    consumableStm e = do
      void $ ST.entryStm e
      guard v_is_fresh
      pure True
removeUnnecessaryCopy _ _ _ _ = Skip
-- | Constant-fold an 'Apply' of a function listed in 'primFuns'.
--
-- The rule fires only if every argument is a 'Constant' and the
-- primitive's evaluator returns 'Just' a value.  The result constant is
-- bound to the original pattern, and the statement's certificates and
-- attributes are kept.  Any other statement is left alone ('Skip').
constantFoldPrimFun :: (BuilderOps rep) => TopDownRuleGeneric rep
constantFoldPrimFun _ (Let pat (StmAux cs attrs _) (Apply fname args _ _))
  | Just args' <- mapM (isConst . fst) args,
    Just (_, _, fun) <- M.lookup (nameToText fname) primFuns,
    Just result <- fun args' =
      Simplify $
        certifying cs $
          attributing attrs $
            letBind pat $
              BasicOp $
                SubExp $
                  Constant result
  where
    -- Extract the value of a constant argument; variables block folding.
    isConst (Constant v) = Just v
    isConst _ = Nothing
constantFoldPrimFun _ _ = Skip
-- | Replace the expression of a single-element binding with a 'Scratch'.
--
-- The rule fires when 'isEmptyArray' recognises the element's type, and
-- the expression is not already a 'Scratch'.  The 'Scratch' reuses the
-- primitive type and shape that 'isEmptyArray' returns.  The statement's
-- aux information is preserved.  Any other statement is left alone
-- ('Skip').
emptyArrayToScratch :: (BuilderOps rep) => TopDownRuleGeneric rep
emptyArrayToScratch _ (Let pat@(Pat [pe]) aux e)
  | Just (pt, shape) <- isEmptyArray $ patElemType pe,
    not $ isScratch e =
      Simplify $ auxing aux $ letBind pat $ BasicOp $ Scratch pt $ shapeDims shape
  where
    -- Prevents the rule from rewriting its own output forever.
    isScratch (BasicOp Scratch {}) = True
    isScratch _ = False
emptyArrayToScratch _ _ = Skip
-- | Simplify an 'Index' bound to a single pattern element by delegating
-- to 'simplifyIndexing'.
--
-- If 'simplifyIndexing' returns a rewrite, the statement's certificates
-- wrap the whole result, and its attributes wrap the rebinding.  The
-- pattern is then rebound either to a plain sub-expression
-- ('SubExpResult') or to a new 'Index' ('IndexResult').  Either way, the
-- extra certificates the result carries are attached.  Any other
-- statement is left alone ('Skip').
simplifyIndex :: (BuilderOps rep) => BottomUpRuleBasicOp rep
simplifyIndex (vtable, used) pat@(Pat [pe]) (StmAux cs attrs _) (Index idd inds)
  | Just m <- simplifyIndexing vtable seType idd inds consumed consuming =
      Simplify $ certifying cs $ do
        res <- m
        attributing attrs $ case res of
          SubExpResult cs' se ->
            certifying cs' $
              letBindNames (patNames pat) $
                BasicOp $
                  SubExp se
          IndexResult extra_cs idd' inds' ->
            certifying extra_cs $
              letBindNames (patNames pat) $
                BasicOp $
                  Index idd' inds'
  where
    -- Whether a name is recorded as consumed in the usage table.
    consuming = (`UT.isConsumed` used)
    -- Whether the result of this indexing is itself consumed.
    consumed = consuming $ patElemName pe
    -- Type lookup for sub-expressions: variables via the symbol table,
    -- constants from their primitive value.
    seType (Var v) = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v
simplifyIndex _ _ _ _ = Skip
withAccTopDown :: (BuilderOps rep) => TopDownRuleGeneric rep
withAccTopDown :: forall rep. BuilderOps rep => TopDownRuleGeneric rep
withAccTopDown TopDown rep
_ (Let Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux (WithAcc [] Lambda rep
lam)) = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep)
-> (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> Rule rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
Result
lam_res <- Body (Rep (RuleM rep)) -> RuleM rep Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Body (Rep (RuleM rep)) -> RuleM rep Result)
-> Body (Rep (RuleM rep)) -> RuleM rep Result
forall a b. (a -> b) -> a -> b
$ Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam
[(VName, SubExpRes)]
-> ((VName, SubExpRes) -> RuleM rep ()) -> RuleM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> Result -> [(VName, SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pat (LetDec rep) -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat (LetDec rep)
pat) Result
lam_res) (((VName, SubExpRes) -> RuleM rep ()) -> RuleM rep ())
-> ((VName, SubExpRes) -> RuleM rep ()) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ \(VName
v, SubExpRes Certs
cs SubExp
se) ->
Certs -> RuleM rep () -> RuleM rep ()
forall a. Certs -> RuleM rep a -> RuleM rep a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
v] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
withAccTopDown TopDown rep
vtable (Let Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux (WithAcc [WithAccInput rep]
inputs Lambda rep
lam)) = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep)
-> (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> Rule rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
let ([Param (LParamInfo rep)]
cert_params, [Param (LParamInfo rep)]
acc_params) =
Int
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([WithAccInput rep] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [WithAccInput rep]
inputs) ([Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)]))
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a b. (a -> b) -> a -> b
$ Lambda rep -> [Param (LParamInfo rep)]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam
(Result
acc_res, Result
nonacc_res) =
Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitFromEnd Int
num_nonaccs (Result -> (Result, Result)) -> Result -> (Result, Result)
forall a b. (a -> b) -> a -> b
$ Body rep -> Result
forall rep. Body rep -> Result
bodyResult (Body rep -> Result) -> Body rep -> Result
forall a b. (a -> b) -> a -> b
$ Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam
([PatElem (LetDec rep)]
acc_pes, [PatElem (LetDec rep)]
nonacc_pes) =
Int
-> [PatElem (LetDec rep)]
-> ([PatElem (LetDec rep)], [PatElem (LetDec rep)])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd Int
num_nonaccs ([PatElem (LetDec rep)]
-> ([PatElem (LetDec rep)], [PatElem (LetDec rep)]))
-> [PatElem (LetDec rep)]
-> ([PatElem (LetDec rep)], [PatElem (LetDec rep)])
forall a b. (a -> b) -> a -> b
$ Pat (LetDec rep) -> [PatElem (LetDec rep)]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec rep)
pat
([[PatElem (LetDec rep)]]
acc_pes', [WithAccInput rep]
inputs', [(Param (LParamInfo rep), Param (LParamInfo rep))]
params', Result
acc_res') <-
([Maybe
([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> ([[PatElem (LetDec rep)]], [WithAccInput rep],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> RuleM
rep
[Maybe
([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
rep
([[PatElem (LetDec rep)]], [WithAccInput rep],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall a b. (a -> b) -> RuleM rep a -> RuleM rep b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ([([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> ([[PatElem (LetDec rep)]], [WithAccInput rep],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall a b c d. [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 ([([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> ([[PatElem (LetDec rep)]], [WithAccInput rep],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> ([Maybe
([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> [([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)])
-> [Maybe
([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> ([[PatElem (LetDec rep)]], [WithAccInput rep],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Maybe
([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> [([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
forall a. [Maybe a] -> [a]
catMaybes) (RuleM
rep
[Maybe
([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
rep
([[PatElem (LetDec rep)]], [WithAccInput rep],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> ([([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
rep
[Maybe
([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)])
-> [([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
rep
([[PatElem (LetDec rep)]], [WithAccInput rep],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)
-> RuleM
rep
(Maybe
([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)))
-> [([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
rep
[Maybe
([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)
-> RuleM
rep
(Maybe
([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes))
forall {m :: * -> *} {dec} {a} {c} {a} {dec}.
MonadBuilder m =>
([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes))
tryMoveAcc ([([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
rep
([[PatElem (LetDec rep)]], [WithAccInput rep],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> [([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
rep
([[PatElem (LetDec rep)]], [WithAccInput rep],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall a b. (a -> b) -> a -> b
$
[[PatElem (LetDec rep)]]
-> [WithAccInput rep]
-> [(Param (LParamInfo rep), Param (LParamInfo rep))]
-> Result
-> [([PatElem (LetDec rep)], WithAccInput rep,
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4
([Int] -> [PatElem (LetDec rep)] -> [[PatElem (LetDec rep)]]
forall a. [Int] -> [a] -> [[a]]
chunks ((WithAccInput rep -> Int) -> [WithAccInput rep] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map WithAccInput rep -> Int
forall {t :: * -> *} {a} {a} {c}. Foldable t => (a, t a, c) -> Int
inputArrs [WithAccInput rep]
inputs) [PatElem (LetDec rep)]
acc_pes)
[WithAccInput rep]
inputs
([Param (LParamInfo rep)]
-> [Param (LParamInfo rep)]
-> [(Param (LParamInfo rep), Param (LParamInfo rep))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (LParamInfo rep)]
cert_params [Param (LParamInfo rep)]
acc_params)
Result
acc_res
let ([Param (LParamInfo rep)]
cert_params', [Param (LParamInfo rep)]
acc_params') = [(Param (LParamInfo rep), Param (LParamInfo rep))]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param (LParamInfo rep), Param (LParamInfo rep))]
params'
([PatElem (LetDec rep)]
nonacc_pes', Result
nonacc_res') <-
[(PatElem (LetDec rep), SubExpRes)]
-> ([PatElem (LetDec rep)], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip ([(PatElem (LetDec rep), SubExpRes)]
-> ([PatElem (LetDec rep)], Result))
-> ([Maybe (PatElem (LetDec rep), SubExpRes)]
-> [(PatElem (LetDec rep), SubExpRes)])
-> [Maybe (PatElem (LetDec rep), SubExpRes)]
-> ([PatElem (LetDec rep)], Result)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Maybe (PatElem (LetDec rep), SubExpRes)]
-> [(PatElem (LetDec rep), SubExpRes)]
forall a. [Maybe a] -> [a]
catMaybes ([Maybe (PatElem (LetDec rep), SubExpRes)]
-> ([PatElem (LetDec rep)], Result))
-> RuleM rep [Maybe (PatElem (LetDec rep), SubExpRes)]
-> RuleM rep ([PatElem (LetDec rep)], Result)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((PatElem (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElem (LetDec rep), SubExpRes)))
-> [(PatElem (LetDec rep), SubExpRes)]
-> RuleM rep [Maybe (PatElem (LetDec rep), SubExpRes)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (PatElem (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElem (LetDec rep), SubExpRes))
tryMoveNonAcc ([PatElem (LetDec rep)]
-> Result -> [(PatElem (LetDec rep), SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem (LetDec rep)]
nonacc_pes Result
nonacc_res)
Bool -> RuleM rep () -> RuleM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([[PatElem (LetDec rep)]] -> [PatElem (LetDec rep)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[PatElem (LetDec rep)]]
acc_pes' [PatElem (LetDec rep)] -> [PatElem (LetDec rep)] -> Bool
forall a. Eq a => a -> a -> Bool
== [PatElem (LetDec rep)]
acc_pes Bool -> Bool -> Bool
&& [PatElem (LetDec rep)]
nonacc_pes' [PatElem (LetDec rep)] -> [PatElem (LetDec rep)] -> Bool
forall a. Eq a => a -> a -> Bool
== [PatElem (LetDec rep)]
nonacc_pes) RuleM rep ()
forall rep a. RuleM rep a
cannotSimplify
Lambda rep
lam' <-
[LParam (Rep (RuleM rep))]
-> RuleM rep Result -> RuleM rep (Lambda (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda ([Param (LParamInfo rep)]
cert_params' [Param (LParamInfo rep)]
-> [Param (LParamInfo rep)] -> [Param (LParamInfo rep)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo rep)]
acc_params') (RuleM rep Result -> RuleM rep (Lambda (Rep (RuleM rep))))
-> RuleM rep Result -> RuleM rep (Lambda (Rep (RuleM rep)))
forall a b. (a -> b) -> a -> b
$
Body (Rep (RuleM rep)) -> RuleM rep Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Body (Rep (RuleM rep)) -> RuleM rep Result)
-> Body (Rep (RuleM rep)) -> RuleM rep Result
forall a b. (a -> b) -> a -> b
$
(Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam) {bodyResult = acc_res' <> nonacc_res'}
Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind ([PatElem (LetDec rep)] -> Pat (LetDec rep)
forall dec. [PatElem dec] -> Pat dec
Pat ([[PatElem (LetDec rep)]] -> [PatElem (LetDec rep)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[PatElem (LetDec rep)]]
acc_pes' [PatElem (LetDec rep)]
-> [PatElem (LetDec rep)] -> [PatElem (LetDec rep)]
forall a. Semigroup a => a -> a -> a
<> [PatElem (LetDec rep)]
nonacc_pes')) (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ [WithAccInput rep] -> Lambda rep -> Exp rep
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [WithAccInput rep]
inputs' Lambda rep
lam'
where
num_nonaccs :: Int
num_nonaccs = [Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (Lambda rep -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda rep
lam) Int -> Int -> Int
forall a. Num a => a -> a -> a
- [WithAccInput rep] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [WithAccInput rep]
inputs
inputArrs :: (a, t a, c) -> Int
inputArrs (a
_, t a
arrs, c
_) = t a -> Int
forall a. t a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length t a
arrs
tryMoveAcc :: ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes))
tryMoveAcc ([PatElem dec]
pes, (a
_, [VName]
arrs, c
_), (a
_, Param dec
acc_p), SubExpRes Certs
cs (Var VName
v))
| Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
acc_p VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v,
Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty = do
[(PatElem dec, VName)] -> ((PatElem dec, VName) -> m ()) -> m ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem dec] -> [VName] -> [(PatElem dec, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem dec]
pes [VName]
arrs) (((PatElem dec, VName) -> m ()) -> m ())
-> ((PatElem dec, VName) -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \(PatElem dec
pe, VName
arr) ->
[VName] -> Exp (Rep m) -> m ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem dec -> VName
forall dec. PatElem dec -> VName
patElemName PatElem dec
pe] (Exp (Rep m) -> m ()) -> Exp (Rep m) -> m ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
arr
Maybe ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes))
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
forall a. Maybe a
Nothing
tryMoveAcc ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
x =
Maybe ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes))
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)))
-> Maybe
([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes))
forall a b. (a -> b) -> a -> b
$ ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> Maybe
([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
forall a. a -> Maybe a
Just ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
x
tryMoveNonAcc :: (PatElem (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElem (LetDec rep), SubExpRes))
tryMoveNonAcc (PatElem (LetDec rep)
pe, SubExpRes Certs
cs (Var VName
v))
| VName
v VName -> TopDown rep -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` TopDown rep
vtable,
Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty = do
[VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
Maybe (PatElem (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElem (LetDec rep), SubExpRes))
forall a. a -> RuleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (PatElem (LetDec rep), SubExpRes)
forall a. Maybe a
Nothing
tryMoveNonAcc (PatElem (LetDec rep)
pe, SubExpRes Certs
cs (Constant PrimValue
v))
| Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty = do
[VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant PrimValue
v
Maybe (PatElem (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElem (LetDec rep), SubExpRes))
forall a. a -> RuleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (PatElem (LetDec rep), SubExpRes)
forall a. Maybe a
Nothing
tryMoveNonAcc (PatElem (LetDec rep), SubExpRes)
x =
Maybe (PatElem (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElem (LetDec rep), SubExpRes))
forall a. a -> RuleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe (PatElem (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElem (LetDec rep), SubExpRes)))
-> Maybe (PatElem (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElem (LetDec rep), SubExpRes))
forall a b. (a -> b) -> a -> b
$ (PatElem (LetDec rep), SubExpRes)
-> Maybe (PatElem (LetDec rep), SubExpRes)
forall a. a -> Maybe a
Just (PatElem (LetDec rep), SubExpRes)
x
withAccTopDown TopDown rep
_ Stm rep
_ = Rule rep
forall rep. Rule rep
Skip
-- | Rewrite a body so that every 'UpdateAcc' statement whose (single) pattern
-- element has an 'Acc' type tagged with one of the names in @get_rid_of@ is
-- replaced by a plain copy of the accumulator it would have updated.
--
-- The traversal descends into nested bodies ('mapOnBody') and into statements
-- held inside 'Op's ('traverseOpStms').
--
-- The second component of the result records, via 'insert', the tag of every
-- update that was removed.  The list is therefore sorted and may contain
-- duplicates.  It is empty iff nothing was rewritten.
elimUpdates ::
  forall rep.
  (ASTRep rep, TraverseOpStms rep) =>
  [VName] ->
  Body rep ->
  (Body rep, [VName])
elimUpdates get_rid_of top_body = runState (goBody top_body) []
  where
    goBody :: Body rep -> State [VName] (Body rep)
    goBody b = do
      new_stms <- goStms (bodyStms b)
      pure b {bodyStms = new_stms}

    goStms :: Stms rep -> State [VName] (Stms rep)
    goStms = mapM goStm

    goStm :: Stm rep -> State [VName] (Stm rep)
    goStm (Let pat aux e) =
      case (pat, e) of
        (Pat [PatElem _ dec], BasicOp (UpdateAcc _ acc _ _))
          | Acc c _ _ _ <- typeOf dec,
            c `elem` get_rid_of -> do
              -- Remember that an update on this accumulator was dropped.
              modify (insert c)
              pure $ Let pat aux $ BasicOp $ SubExp $ Var acc
        -- Not a removable update: recurse into any nested bodies/ops.
        _ -> Let pat aux <$> goExp e

    goExp :: Exp rep -> State [VName] (Exp rep)
    goExp = mapExpM walker

    -- NOTE: the explicit ascription on 'identityMapper' is kept as-is; it
    -- appears to be required for this record update to typecheck.
    walker :: Mapper rep rep (State [VName])
    walker =
      (identityMapper :: forall m. (Monad m) => Mapper rep rep m)
        { mapOnOp = traverseOpStms (const goStms),
          mapOnBody = const goBody
        }
-- | Bottom-up dead-result elimination for 'WithAcc'.
--
-- The rule fires only when at least one name bound by the 'WithAcc' pattern
-- is unused.  It then does two things:
--
--   * For every input accumulator none of whose pattern elements are used, it
--     collects the corresponding certificate parameter name.  It then turns
--     every 'UpdateAcc' tagged with one of those names inside the lambda into
--     a plain copy via 'elimUpdates'.  The accumulator results themselves stay
--     in the pattern.
--
--   * It drops every non-accumulator result whose pattern element is unused.
--
-- If there is nothing to do, or 'elimUpdates' removes nothing and no
-- non-accumulator result is dropped, the rule declines with 'cannotSimplify'.
withAccBottomUp :: (TraverseOpStms rep, BuilderOps rep) => BottomUpRuleGeneric rep
withAccBottomUp (_, utable) (Let pat aux (WithAcc inputs lam))
  | any (not . isUsed) (patNames pat) = Simplify $ do
      let body = lambdaBody lam
          n_inputs = length inputs
          -- Results beyond the per-input accumulators are ordinary values.
          num_nonaccs = length (lambdaReturnType lam) - n_inputs
          (acc_res, nonacc_res) = splitFromEnd num_nonaccs (bodyResult body)
          (acc_pes, nonacc_pes) = splitFromEnd num_nonaccs (patElems pat)
          (cert_params, acc_params) = splitAt n_inputs (lambdaParams lam)
          -- One group of accumulator pattern elements per input (one per array).
          pes_per_input = chunks [length arrs | (_, arrs, _) <- inputs] acc_pes
          -- Certificate names of accumulators whose results are all dead.
          dead_certs =
            [ paramName p
            | (pes, p) <- zip pes_per_input cert_params,
              not (any (isUsed . patElemName) pes)
            ]
          -- Non-accumulator results that are still read somewhere.
          (live_nonacc_pes, live_nonacc_res) =
            unzip
              [ (pe, r)
              | (pe, r) <- zip nonacc_pes nonacc_res,
                isUsed (patElemName pe)
              ]
          nonaccs_unchanged = live_nonacc_pes == nonacc_pes

      when (null dead_certs && nonaccs_unchanged) cannotSimplify

      let (body', eliminated) = elimUpdates dead_certs body

      -- Guard against reporting progress when no update was actually removed.
      when (null eliminated && nonaccs_unchanged) cannotSimplify

      lam' <- mkLambda (cert_params ++ acc_params) $ do
        void $ bodyBind body'
        pure (acc_res ++ live_nonacc_res)

      auxing aux $
        letBind (Pat (acc_pes ++ live_nonacc_pes)) (WithAcc inputs lam')
  where
    isUsed = (`UT.used` utable)
withAccBottomUp _ _ = Skip