module Futhark.Optimise.Simplify.Rules.ClosedForm
( foldClosedForm,
loopClosedForm,
)
where
import Control.Monad
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.Simple (VarLookup)
import Futhark.Transform.Rename
-- | @foldClosedForm look pat lam accs arrs@ tries to replace a fold
-- (lambda @lam@, initial accumulators @accs@, input arrays @arrs@)
-- with a directly computed expression bound to @pat@.
--
-- The rule only proceeds when @pat@ binds exactly one primitive,
-- non-float value; any other pattern aborts with 'cannotSimplify'.
-- The closed-form body is produced by 'checkResults', which itself
-- aborts when the lambda body has an unsupported shape.
--
-- On success, @pat@ is bound to a 'Match' on whether the input size
-- equals zero: if so the result is @accs@ unchanged, otherwise it is
-- a renamed copy of the closed-form body.
foldClosedForm ::
  (BuilderOps rep) =>
  VarLookup rep ->
  Pat (LetDec rep) ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  RuleM rep ()
foldClosedForm look pat lam accs arrs = do
  arrTypes <- mapM lookupType arrs
  let inputsize = arraysSize 0 arrTypes

  resType <- case patTypes pat of
    [Prim (FloatType {})] -> cannotSimplify
    [Prim pt] -> pure pt
    _ -> cannotSimplify

  closedBody <-
    checkResults
      (patNames pat)
      inputsize
      mempty
      Int64
      knownBnds
      lamParamNames
      (lambdaBody lam)
      accs

  -- Guard for the empty-input case, where the fold yields its
  -- initial accumulators.
  isEmpty <- newVName "fold_input_is_empty"
  letBindNames [isEmpty] . BasicOp $
    CmpOp (CmpEq int64) inputsize (intConst Int64 0)

  emptyBody <- resultBodyM accs
  nonEmptyBody <- renameBody closedBody
  letBind pat $
    Match
      [Var isEmpty]
      [Case [Just (BoolValue True)] emptyBody]
      nonEmptyBody
      (MatchDec [primBodyType resType] MatchNormal)
  where
    lamParamNames = map paramName (lambdaParams lam)
    knownBnds = determineKnownBindings look lam accs arrs
-- | @loopClosedForm pat merge i it bound body@ tries to replace a
-- counted loop with a directly computed expression bound to @pat@.
--
-- * @merge@ pairs each loop parameter with its initial value.
-- * @i@ is a set of names that 'checkResults' must not treat as free
--   in the body.
-- * @it@ is the integer type of @bound@, the iteration count.
-- * @body@ is the loop body.
--
-- The rule only proceeds when @pat@ binds exactly one primitive,
-- non-float value; any other pattern aborts with 'cannotSimplify'.
--
-- On success, @pat@ is bound to a 'Match': when the bound is
-- non-positive the result is the initial merge values, otherwise it
-- is a renamed copy of the closed-form body from 'checkResults'.
loopClosedForm ::
  (BuilderOps rep) =>
  Pat (LetDec rep) ->
  [(FParam rep, SubExp)] ->
  Names ->
  IntType ->
  SubExp ->
  Body rep ->
  RuleM rep ()
loopClosedForm pat merge i it bound body = do
  resType <- case patTypes pat of
    [Prim (FloatType {})] -> cannotSimplify
    [Prim pt] -> pure pt
    _ -> cannotSimplify

  closedBody <-
    checkResults
      mergenames
      bound
      i
      it
      knownBnds
      (map identName mergeidents)
      body
      mergeexp

  -- A loop whose bound is zero *or* negative runs no iterations, so
  -- both cases must yield the initial merge values.  A strict
  -- less-than test would send a zero bound into the closed-form
  -- branch, which for 'LogAnd' computes @acc && el@ rather than
  -- @acc@; hence the signed less-than-or-equal comparison.
  isEmpty <- newVName "bound_is_zero"
  letBindNames [isEmpty] $
    BasicOp $
      CmpOp (CmpSle it) bound (intConst it 0)

  letBind pat
    =<< ( Match [Var isEmpty]
            <$> (pure . Case [Just (BoolValue True)] <$> resultBodyM mergeexp)
            <*> renameBody closedBody
            <*> pure (MatchDec [primBodyType resType] MatchNormal)
        )
  where
    (mergepat, mergeexp) = unzip merge
    mergeidents = map paramIdent mergepat
    mergenames = map paramName mergepat
    -- Each loop parameter is known to start out as its initial value.
    knownBnds = M.fromList $ zip mergenames mergeexp
-- | Build a body that computes every result of @body@ directly
-- instead of by iteration, or abort with 'cannotSimplify'.
--
-- * @pat@: names that the computed results are bound to, and that
--   the returned body yields.
-- * @size@: the iteration count; @it@ is its integer type.
-- * @untouchable@: extra names that are never considered free.
-- * @knownBnds@: substitutions for names that are not free in the
--   body but whose value is known.
-- * @params@: parameter names of the body; the first @length accs@
--   of them are the accumulator parameters.
-- * @accs@: initial accumulator values.
--
-- Each result must be a variable bound in @body@ by a single-name
-- statement applying a binary operator to two distinct operands.
-- One operand must be the matching accumulator parameter and the
-- other must resolve via @asFreeSubExp@.  The supported operators
-- are 'LogAnd', 'Add' and 'FAdd'.
checkResults ::
  (BuilderOps rep) =>
  [VName] ->
  SubExp ->
  Names ->
  IntType ->
  M.Map VName SubExp ->
  [VName] ->
  Body rep ->
  [SubExp] ->
  RuleM rep (Body rep)
checkResults pat size untouchable it knownBnds params body accs = do
  (_, stms) <-
    collectStms $
      zipWithM_ checkResult (zip pat (bodyResult body)) (zip accparams accs)
  mkBodyM stms (varsRes pat)
  where
    stmMap = makeBindMap body
    accparams = take (length accs) params

    -- Names that cannot be referenced from outside the body.
    nonFree = boundInBody body <> namesFromList params <> untouchable

    checkResult (p, SubExpRes _ (Var v)) (accparam, acc)
      | Just (BasicOp (BinOp bop x y)) <- M.lookup v stmMap,
        x /= y = do
          el <- liftMaybe $ freeOperand accparam x y
          -- Binds @p@ to @acc `addOp` (el `mulOp` n)@.
          let bindScaledSum addOp mulOp n =
                letBindNames [p]
                  =<< eBinOp
                    addOp
                    (eSubExp acc)
                    (pure . BasicOp $ BinOp mulOp el n)
          case bop of
            LogAnd ->
              letBindNames [p] . BasicOp $ BinOp LogAnd acc el
            Add t w ->
              asIntS t size >>= bindScaledSum (Add t w) (Mul t w)
            FAdd t ->
              sizeAsFloat t >>= bindScaledSum (FAdd t) (FMul t)
            _ -> cannotSimplify
    checkResult _ _ = cannotSimplify

    -- When one operand is this result's accumulator parameter, yield
    -- the other operand resolved through 'asFreeSubExp'.  The
    -- combination (x free, y accumulator) is tried first.
    freeOperand accparam x y
      | y == Var accparam, Just free <- asFreeSubExp x = Just free
      | x == Var accparam, Just free <- asFreeSubExp y = Just free
      | otherwise = Nothing

    -- A variable that is not free in the body is usable only if a
    -- replacement for it is known; everything else is used as is.
    asFreeSubExp :: SubExp -> Maybe SubExp
    asFreeSubExp (Var v)
      | v `nameIn` nonFree = M.lookup v knownBnds
    asFreeSubExp se = Just se

    -- The iteration count converted from signed integer type @it@ to
    -- the float type @t@.
    sizeAsFloat t =
      letSubExp "converted_size" . BasicOp $
        ConvOp (SIToFP it t) size
-- | Map lambda parameter names to values that are known up front.
--
-- The first @length accs@ parameters of @lam@ are paired with the
-- initial accumulators @accs@.  Each remaining parameter is paired
-- positionally with an array in @arrs@, and is included only when
-- @look@ reports that array as a certificate-free 'Replicate' with a
-- non-empty shape; the parameter then maps to the replicated value.
--
-- The union is left-biased, so on a key clash the accumulator
-- binding takes precedence.
determineKnownBindings ::
  VarLookup rep ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  M.Map VName SubExp
determineKnownBindings look lam accs arrs =
  M.fromList (zip accNames accs) <> M.fromList replicated
  where
    (accNames, arrNames) =
      splitAt (length accs) (map paramName (lambdaParams lam))

    replicated =
      [ (p, ve)
        | (p, v) <- zip arrNames arrs,
          Just (BasicOp (Replicate (Shape (_ : _)) ve), cs) <- [look v],
          cs == mempty
      ]
-- | Index the statements of a body by the name they bind.  Only
-- statements whose pattern binds exactly one name are included;
-- all other statements are skipped.
makeBindMap :: Body rep -> M.Map VName (Exp rep)
makeBindMap body =
  M.fromList
    [ (v, e)
      | Let pat _ e <- stmsToList (bodyStms body),
        [v] <- [patNames pat]
    ]