{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Futhark.IR.SOACS.Simplify
  ( simplifySOACS,
    simplifyLambda,
    simplifyFun,
    simplifyStms,
    simplifyConsts,
    simpleSOACS,
    simplifySOAC,
    soacRules,
    HasSOAC (..),
    simplifyKnownIterationSOAC,
    removeReplicateMapping,
    removeUnusedSOACInput,
    liftIdentityMapping,
    simplifyMapIota,
    SOACS,
  )
where

import Control.Monad
import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Writer
import Data.Bifunctor
import Data.Either
import Data.Foldable
import Data.List (partition, transpose, unzip4)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS hiding (reshapeInner)
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify qualified as Simplify
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules
import Futhark.Optimise.Simplify.Rules.ClosedForm
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util

-- | The simplifier operations for the SOACS representation: the
-- bindable default operations, with 'simplifySOAC' used to simplify
-- the representation-specific 'Op's.
simpleSOACS :: Simplify.SimpleOps SOACS
simpleSOACS = Simplify.bindableSimpleOps simplifySOAC

-- | Simplify an entire SOACS program, using 'simpleSOACS' and
-- 'soacRules' and no extra hoist blockers.
simplifySOACS :: Prog SOACS -> PassM (Prog SOACS)
simplifySOACS =
  Simplify.simplifyProg simpleSOACS soacRules Engine.noExtraHoistBlockers

-- | Simplify a single SOACS function definition in the context of
-- the given symbol table, using 'simpleSOACS' and 'soacRules' and no
-- extra hoist blockers.
simplifyFun ::
  (MonadFreshNames m) =>
  ST.SymbolTable (Wise SOACS) ->
  FunDef SOACS ->
  m (FunDef SOACS)
simplifyFun =
  Simplify.simplifyFun simpleSOACS soacRules Engine.noExtraHoistBlockers

-- | Simplify a single SOACS lambda, using 'simpleSOACS' and
-- 'soacRules' and no extra hoist blockers.  The scope is taken from
-- the surrounding monad.
simplifyLambda ::
  (HasScope SOACS m, MonadFreshNames m) => Lambda SOACS -> m (Lambda SOACS)
simplifyLambda =
  Simplify.simplifyLambda simpleSOACS soacRules Engine.noExtraHoistBlockers

-- | Simplify a sequence of SOACS statements.  The scope in which the
-- statements are simplified is obtained from the surrounding monad
-- via 'askScope'.
simplifyStms ::
  (HasScope SOACS m, MonadFreshNames m) => Stms SOACS -> m (Stms SOACS)
simplifyStms stms = do
  scope <- askScope
  Simplify.simplifyStms
    simpleSOACS
    soacRules
    Engine.noExtraHoistBlockers
    scope
    stms

-- | Simplify the statements defining program constants.  Unlike
-- 'simplifyStms', the statements are simplified in an empty
-- ('mempty') scope rather than one taken from the monad.
simplifyConsts ::
  (MonadFreshNames m) => Stms SOACS -> m (Stms SOACS)
simplifyConsts =
  Simplify.simplifyStms simpleSOACS soacRules Engine.noExtraHoistBlockers mempty

-- | Simplify a single 'SOAC'.  Returns the simplified SOAC together
-- with the statements that 'Engine.simplifyLambda' returned alongside
-- each simplified lambda, concatenated in the order the lambdas were
-- processed.
--
-- The lambdas of 'Stream', the operator and bucket lambdas of 'Hist',
-- and the map lambda of 'Screma' are simplified inside
-- 'Engine.enterLoop'.  The lambdas of 'VJP', 'JVP', and the scan and
-- reduction operators of 'Screma' are simplified without it.
simplifySOAC ::
  (Simplify.SimplifiableRep rep) =>
  Simplify.SimplifyOp rep (SOAC (Wise rep))
simplifySOAC (VJP arr vec lam) = do
  (lam', hoisted) <- Engine.simplifyLambda mempty lam
  arr' <- mapM Engine.simplify arr
  vec' <- mapM Engine.simplify vec
  pure (VJP arr' vec' lam', hoisted)
simplifySOAC (JVP arr vec lam) = do
  (lam', hoisted) <- Engine.simplifyLambda mempty lam
  arr' <- mapM Engine.simplify arr
  vec' <- mapM Engine.simplify vec
  pure (JVP arr' vec' lam', hoisted)
simplifySOAC (Stream outerdim arr nes lam) = do
  outerdim' <- Engine.simplify outerdim
  nes' <- mapM Engine.simplify nes
  arr' <- mapM Engine.simplify arr
  (lam', lam_hoisted) <- Engine.enterLoop $ Engine.simplifyLambda mempty lam
  pure (Stream outerdim' arr' nes' lam', lam_hoisted)
simplifySOAC (Hist w imgs ops bfun) = do
  w' <- Engine.simplify w
  -- Simplify every histogram operation, collecting the statements
  -- returned for each operator lambda.
  (ops', hoisted) <- fmap unzip $
    forM ops $ \(HistOp dests_w rf dests nes op) -> do
      dests_w' <- Engine.simplify dests_w
      rf' <- Engine.simplify rf
      dests' <- Engine.simplify dests
      nes' <- mapM Engine.simplify nes
      (op', hoisted) <- Engine.enterLoop $ Engine.simplifyLambda mempty op
      pure (HistOp dests_w' rf' dests' nes' op', hoisted)
  imgs' <- mapM Engine.simplify imgs
  (bfun', bfun_hoisted) <- Engine.enterLoop $ Engine.simplifyLambda mempty bfun
  pure (Hist w' imgs' ops' bfun', mconcat hoisted <> bfun_hoisted)
simplifySOAC (Screma w arrs (ScremaForm map_lam scans reds)) = do
  (scans', scans_hoisted) <- fmap unzip $
    forM scans $ \(Scan lam nes) -> do
      (lam', hoisted) <- Engine.simplifyLambda mempty lam
      nes' <- Engine.simplify nes
      pure (Scan lam' nes', hoisted)

  (reds', reds_hoisted) <- fmap unzip $
    forM reds $ \(Reduce comm lam nes) -> do
      (lam', hoisted) <- Engine.simplifyLambda mempty lam
      nes' <- Engine.simplify nes
      pure (Reduce comm lam' nes', hoisted)

  (map_lam', map_lam_hoisted) <-
    Engine.enterLoop $ Engine.simplifyLambda mempty map_lam

  (,)
    <$> ( Screma
            <$> Engine.simplify w
            <*> Engine.simplify arrs
            <*> pure (ScremaForm map_lam' scans' reds')
        )
    <*> pure (mconcat scans_hoisted <> mconcat reds_hoisted <> map_lam_hoisted)

-- No methods are overridden here, so whatever defaults the
-- 'BuilderOps' class provides are used for @Wise SOACS@.
instance BuilderOps (Wise SOACS)

-- The operations of @Wise SOACS@ are SOACs, so statement traversal
-- is delegated directly to 'traverseSOACStms'.
instance TraverseOpStms (Wise SOACS) where
  traverseOpStms = traverseSOACStms

-- | Fix some of the parameters of a lambda to known values.
--
-- The second argument is matched positionally against the lambda's
-- parameters.  For a parameter paired with @Just x@, a statement
-- binding the parameter's name to @x@ is emitted ahead of the
-- original body, and the parameter is removed from the parameter
-- list.  Parameters paired with 'Nothing' are kept.  If the list is
-- shorter than the parameter list it is padded with 'Nothing', so the
-- remaining parameters are kept too.
fixLambdaParams ::
  (MonadBuilder m, Buildable (Rep m), BuilderOps (Rep m)) =>
  Lambda (Rep m) ->
  [Maybe SubExp] ->
  m (Lambda (Rep m))
fixLambdaParams lam fixes = do
  -- Rebuild the body with the lambda's parameters in scope: first the
  -- bindings for the fixed parameters, then the original body.
  body <- runBodyBuilder $
    localScope (scopeOfLParams $ lambdaParams lam) $ do
      zipWithM_ maybeFix (lambdaParams lam) fixes'
      bodyBind $ lambdaBody lam
  pure
    lam
      { lambdaBody = body,
        lambdaParams =
          map fst $
            filter (isNothing . snd) $
              zip (lambdaParams lam) fixes'
      }
  where
    -- Pad with 'Nothing' so that every parameter has a partner.
    fixes' :: [Maybe SubExp]
    fixes' = fixes ++ repeat Nothing

    maybeFix p (Just x) = letBindNames [paramName p] $ BasicOp $ SubExp x
    maybeFix _ Nothing = pure ()

-- | Remove some of the results of a lambda.  The body result and the
-- return type are both filtered positionally by the given mask: an
-- entry is kept when the corresponding 'Bool' is 'True'.  Positions
-- beyond the end of the mask are kept, since the mask is padded with
-- 'True'.  The statements of the body are left untouched.
removeLambdaResults :: [Bool] -> Lambda rep -> Lambda rep
removeLambdaResults keep lam =
  lam
    { lambdaBody = lam_body',
      lambdaReturnType = ret
    }
  where
    -- Select the elements whose mask entry is 'True'.
    keep' :: [a] -> [a]
    keep' = map snd . filter fst . zip (keep ++ repeat True)
    lam_body = lambdaBody lam
    lam_body' = lam_body {bodyResult = keep' $ bodyResult lam_body}
    ret = keep' $ lambdaReturnType lam

-- | The simplification rules used for the SOACS representation: the
-- standard rules combined with the SOAC-specific 'topDownRules' and
-- 'bottomUpRules' of this module.
soacRules :: RuleBook (Wise SOACS)
soacRules = standardRules <> ruleBook topDownRules bottomUpRules

-- | Does this rep contain 'SOAC's in its t'Op's?  A rep must be an
-- instance of this class for the simplification rules to work.
class HasSOAC rep where
  -- | View an operation of the rep as a 'SOAC'.  Presumably 'Nothing'
  -- means the operation is not a SOAC; confirm against the instances.
  asSOAC :: Op rep -> Maybe (SOAC rep)
  -- | Turn a 'SOAC' into an operation of the rep.
  soacOp :: SOAC rep -> Op rep

-- For the SOACS rep the operation type /is/ 'SOAC', so the view
-- always succeeds and the embedding is the identity.
instance HasSOAC (Wise SOACS) where
  asSOAC = Just
  soacOp = id

-- The SOAC-specific top-down rules.  Every entry is a rule on
-- operations ('RuleOp').
topDownRules :: [TopDownRule (Wise SOACS)]
topDownRules =
  [ RuleOp hoistCerts,
    RuleOp removeReplicateMapping,
    RuleOp removeUnusedSOACInput,
    RuleOp simplifyClosedFormReduce,
    RuleOp simplifyKnownIterationSOAC,
    RuleOp liftIdentityMapping,
    RuleOp removeDuplicateMapOutput,
    RuleOp fuseConcatScatter,
    RuleOp simplifyMapIota,
    RuleOp moveTransformToInput,
    RuleOp moveTransformToOutput
  ]

-- The SOAC-specific bottom-up rules.  All are rules on operations
-- ('RuleOp') except 'removeUnnecessaryCopy', which is a rule on basic
-- operations ('RuleBasicOp').
bottomUpRules :: [BottomUpRule (Wise SOACS)]
bottomUpRules =
  [ RuleOp removeDeadMapping,
    RuleOp removeDeadReduction,
    RuleBasicOp removeUnnecessaryCopy,
    RuleOp liftIdentityStreaming,
    RuleOp mapOpToOp
  ]

-- Any certificates attached to a trivial Stm in the body might as
-- well be applied to the SOAC itself.
--
-- Every lambda of the SOAC is visited.  For each statement directly
-- in a lambda body that merely binds a 'SubExp', the certificates
-- that are present in the symbol table are stripped from the
-- statement and collected; the remaining certificates stay on the
-- statement.  If anything was collected, the SOAC is rebound with the
-- collected certificates attached.  Otherwise the rule does not
-- apply.
hoistCerts :: TopDownRuleOp (Wise SOACS)
hoistCerts vtable pat aux soac
  | (soac', hoisted) <- runState (mapSOACM mapper soac) mempty,
    hoisted /= mempty =
      Simplify $ auxing aux $ certifying hoisted $ letBind pat $ Op soac'
  where
    -- Only the lambdas of the SOAC are transformed; everything else
    -- is left as it is.
    mapper = identitySOACMapper {mapOnSOACLambda = onLambda}

    -- Rewrite the statements of the lambda body, keeping its result.
    onLambda lam = do
      stms' <- mapM onStm $ bodyStms $ lambdaBody lam
      pure
        lam
          { lambdaBody =
              mkBody stms' $ bodyResult $ lambdaBody lam
          }

    -- Split the certificates of a 'SubExp' statement into those that
    -- are in the symbol table (collected in the state) and the rest
    -- (kept on the statement).
    onStm (Let se_pat se_aux (BasicOp (SubExp se))) = do
      let (invariant, variant) =
            partition (`ST.elem` vtable) $
              unCerts $
                stmAuxCerts se_aux
          se_aux' = se_aux {stmAuxCerts = Certs variant}
      modify (Certs invariant <>)
      pure $ Let se_pat se_aux' $ BasicOp $ SubExp se
    -- All other statements are untouched.
    onStm stm = pure stm
hoistCerts _ _ _ _ =
  Skip

-- | Remove results of a map whose value does not depend on the work
-- done inside the map's lambda.
--
-- Each result of the lambda is classified as follows:
--
-- * If it is a lambda parameter, the corresponding output is bound
--   directly in terms of the matching input array: as the array
--   itself for accumulator-typed outputs, and as @Replicate mempty@
--   of the array otherwise.
--
-- * If it is a constant, or a variable free in the lambda body, the
--   output is bound to a replicate of that value with shape @[w]@.
--
-- * Otherwise the result stays in the map.
--
-- The rule applies only when at least one result is lifted out.  It
-- then emits the lifted bindings followed by a map that produces only
-- the remaining results.
liftIdentityMapping ::
  forall rep.
  (Buildable rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
liftIdentityMapping _ pat aux op
  | Just (Screma w arrs form :: SOAC rep) <- asSOAC op,
    Just fun <- isMapSOAC form = do
      let -- Lambda parameter name -> the input array it ranges over.
          inputMap = M.fromList $ zip (map paramName $ lambdaParams fun) arrs
          free = freeIn $ lambdaBody fun
          rettype = lambdaReturnType fun
          ses = bodyResult $ lambdaBody fun

          freeOrConst (Var v) = v `nameIn` free
          freeOrConst Constant {} = True

          -- Fold step.  The accumulator holds the lifted bindings,
          -- the (pattern element, result) pairs that remain in the
          -- map, and the return types of the remaining results.
          checkInvariance (outId, SubExpRes _ (Var v), _) (invariant, mapresult, rettype')
            | Just inp <- M.lookup v inputMap =
                ( (Pat [outId], e inp) : invariant,
                  mapresult,
                  rettype'
                )
            where
              e inp = case patElemType outId of
                Acc {} -> BasicOp $ SubExp $ Var inp
                _ -> BasicOp (Replicate mempty (Var inp))
          checkInvariance (outId, res@(SubExpRes _ e), t) (invariant, mapresult, rettype')
            | freeOrConst e =
                ( (Pat [outId], BasicOp $ Replicate (Shape [w]) e) : invariant,
                  mapresult,
                  rettype'
                )
            | otherwise =
                -- Keep the original 'SubExpRes', so that any
                -- certificates on a result that stays inside the
                -- lambda are preserved.
                ( invariant,
                  (outId, res) : mapresult,
                  t : rettype'
                )

      case foldr checkInvariance ([], [], []) $
        zip3 (patElems pat) ses rettype of
        -- Nothing could be lifted out.
        ([], _, _) -> Skip
        (invariant, mapresult, rettype') -> Simplify $ do
          let (pat', ses') = unzip mapresult
              fun' =
                fun
                  { lambdaBody = (lambdaBody fun) {bodyResult = ses'},
                    lambdaReturnType = rettype'
                  }
          auxing aux $ do
            mapM_ (uncurry letBind) invariant
            letBindNames (map patElemName pat') . Op $
              soacOp (Screma w arrs (mapSOAC fun'))
liftIdentityMapping _ _ _ _ = Skip

-- Remove map-style results of a 'Stream' that are simply one of the
-- lambda's array parameters returned unchanged.
--
-- The first @length nes@ results of the lambda (and the corresponding
-- pattern elements and return types) are treated as fold results and
-- always kept.  Each remaining result that is a variable naming an
-- array parameter of the lambda is bound to @Replicate mempty@ of the
-- corresponding input array, and dropped from the stream.  The rule
-- applies only when at least one such result exists.
liftIdentityStreaming :: BottomUpRuleOp (Wise SOACS)
liftIdentityStreaming _ (Pat pes) aux (Stream w arrs nes lam)
  | (variant_map, invariant_map) <-
      partitionEithers $ map isInvariantRes $ zip3 map_ts map_pes map_res,
    not $ null invariant_map = Simplify $ do
      -- Bind every invariant output in terms of its input array.
      forM_ invariant_map $ \(pe, arr) ->
        letBind (Pat [pe]) $ BasicOp $ Replicate mempty $ Var arr

      let (variant_map_ts, variant_map_pes, variant_map_res) = unzip3 variant_map
          lam' =
            lam
              { lambdaBody = (lambdaBody lam) {bodyResult = fold_res ++ variant_map_res},
                lambdaReturnType = fold_ts ++ variant_map_ts
              }

      -- Rebuild the stream with only the fold results and the map
      -- results that really are computed by the lambda.
      auxing aux . letBind (Pat $ fold_pes ++ variant_map_pes) . Op $
        Stream w arrs nes lam'
  where
    num_folds = length nes
    (fold_pes, map_pes) = splitAt num_folds pes
    (fold_ts, map_ts) = splitAt num_folds $ lambdaReturnType lam
    lam_res = bodyResult $ lambdaBody lam
    (fold_res, map_res) = splitAt num_folds lam_res
    -- The first @1 + num_folds@ lambda parameters are skipped
    -- (presumably the chunk size followed by the accumulators -- TODO
    -- confirm); the rest are paired with the input arrays.
    params_to_arrs = zip (map paramName $ drop (1 + num_folds) $ lambdaParams lam) arrs

    -- 'Right' for a result that is just an array parameter, 'Left'
    -- for a result that must remain in the stream.
    isInvariantRes (_, pe, SubExpRes _ (Var v))
      | Just arr <- lookup v params_to_arrs =
          Right (pe, arr)
    isInvariantRes x =
      Left x
liftIdentityStreaming _ _ _ _ = Skip

-- | Top-down rule for map-shaped 'Screma's.  Every input array for
-- which 'removeReplicateInput' finds a replacement is dropped from the
-- SOAC; the lambda parameter that used to receive its elements is
-- instead bound by an ordinary statement emitted just before the
-- rewritten SOAC.  The rule is skipped when the operation is not a
-- map, or when no input can be removed.
removeReplicateMapping ::
  (Aliased rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
removeReplicateMapping vtable pat aux op =
  case asSOAC op of
    Just (Screma w arrs form)
      | Just fun <- isMapSOAC form,
        Just (stms, fun', arrs') <- removeReplicateInput vtable fun arrs ->
          Simplify $ do
            -- First bind the former lambda parameters as free
            -- variables, then emit the map over the remaining inputs.
            mapM_ bindParam stms
            auxing aux . letBind pat . Op . soacOp $
              Screma w arrs' (mapSOAC fun')
    _ -> Skip
  where
    -- Each binding carries the certificates that were attached to the
    -- expression it was derived from.
    bindParam (vs, cs, e) = certifying cs $ letBindNames vs e

-- | Split the array inputs of a lambda into those that must stay and
-- those that can be replaced by a plain binding.
--
-- An input can be replaced when the symbol table shows it was
-- produced by a 'Replicate' with at least one outer dimension, and
-- the corresponding lambda parameter is not consumed by the lambda.
-- For every such input the result contains a triple of the parameter
-- name to bind, the certificates of the replicate, and the expression
-- to bind it to (the replicated value itself, or a 'Replicate' over
-- the remaining inner dimensions).
--
-- Returns 'Nothing' when no input can be replaced; otherwise the
-- bindings, the lambda with the replaced parameters removed, and the
-- arrays that remain as inputs.  Leading parameters that have no
-- corresponding array are always kept.
removeReplicateInput ::
  (Aliased rep) =>
  ST.SymbolTable rep ->
  Lambda rep ->
  [VName] ->
  Maybe
    ( [([VName], Certs, Exp rep)],
      Lambda rep,
      [VName]
    )
removeReplicateInput vtable fun arrs =
  case partitionEithers (zipWith classify arr_params arrs) of
    (_, []) -> Nothing
    (kept, hoisted) ->
      let (kept_params, kept_arrs) = unzip kept
       in Just
            ( hoisted,
              fun {lambdaParams = acc_params <> kept_params},
              kept_arrs
            )
  where
    all_params = lambdaParams fun

    -- The trailing @length arrs@ parameters are the ones fed from the
    -- input arrays; whatever precedes them is left untouched.
    (acc_params, arr_params) =
      splitAt (length all_params - length arrs) all_params

    consumed = consumedByLambda fun

    -- 'Right' for an input that can be hoisted, 'Left' for one that
    -- has to remain a SOAC input.
    classify p v =
      case ST.lookupExp v vtable of
        Just (BasicOp (Replicate (Shape (_ : inner)) e), v_cs)
          | paramName p `notNameIn` consumed ->
              Right ([paramName p], v_cs, BasicOp (unreplicated inner e))
        _ -> Left (p, v)

    -- What a single element of the replicate looks like once the
    -- outermost dimension has been peeled off.
    unreplicated [] e = SubExp e
    unreplicated inner e = Replicate (Shape inner) e

-- | Top-down rule for 'Screma': drop every input array whose lambda
-- parameter does not occur free in the body of the map lambda,
-- together with that parameter.  The scan and reduce parts are passed
-- through unchanged.  The rule is skipped when all inputs are used.
removeUnusedSOACInput ::
  forall rep.
  (Aliased rep, Buildable rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
removeUnusedSOACInput _ pat aux op
  | Just (Screma w arrs (ScremaForm map_lam scan reduce) :: SOAC rep) <-
      asSOAC op,
    let body_free = freeIn (lambdaBody map_lam)
        isLive (param, _) = paramName param `nameIn` body_free
        (live, dead) = partition isLive (zip (lambdaParams map_lam) arrs),
    -- Only fire when there is actually something to remove.
    not (null dead) =
      let (live_params, live_arrs) = unzip live
          pruned_lam = map_lam {lambdaParams = live_params}
       in Simplify . auxing aux . letBind pat . Op . soacOp $
            Screma w live_arrs (ScremaForm pruned_lam scan reduce)
removeUnusedSOACInput _ _ _ _ = Skip

-- | Bottom-up rule for 'Screma': remove map outputs whose pattern
-- element is not used according to the usage table.  The leading
-- scan and reduction results (as counted by 'scanResults' and
-- 'redResults') are never touched; only the results after them are
-- candidates.  The lambda's result list and return type are pruned in
-- step with the pattern.  The rule is skipped when there are no map
-- outputs, or when all of them are used.
removeDeadMapping :: BottomUpRuleOp (Wise SOACS)
removeDeadMapping (_, used) (Pat pes) aux (Screma w arrs (ScremaForm lam scans reds))
  | not (null map_pes),
    map_pes /= live_pes =
      Simplify . auxing aux $
        letBind (Pat (nonmap_pes <> live_pes)) $
          Op (Screma w arrs (ScremaForm pruned_lam scans reds))
  where
    num_nonmap_res = scanResults scans + redResults reds

    -- Pattern, lambda results and lambda return types are all split
    -- at the same position.
    (nonmap_pes, map_pes) = splitAt num_nonmap_res pes
    (nonmap_res, map_res) = splitAt num_nonmap_res (bodyResult (lambdaBody lam))
    (nonmap_ts, map_ts) = splitAt num_nonmap_res (lambdaReturnType lam)

    isLive (pe, _, _) = patElemName pe `UT.used` used

    (live_pes, live_res, live_ts) =
      unzip3 [out | out <- zip3 map_pes map_res map_ts, isLive out]

    pruned_lam =
      lam
        { lambdaBody = (lambdaBody lam) {bodyResult = nonmap_res <> live_res},
          lambdaReturnType = nonmap_ts <> live_ts
        }
removeDeadMapping _ _ _ _ = Skip

-- | Top-down rule for map-shaped 'Screma's whose lambda returns the
-- same sub-expression in more than one result position.  Only the
-- first occurrence is kept as an output of the SOAC; every later
-- pattern element is instead bound, after the SOAC, by a 'Replicate'
-- with an empty shape of the array produced for the first occurrence.
-- The rule is skipped when the operation is not a map or when no
-- result is duplicated.
--
-- Results are compared on their 'resSubExp' component only; the
-- certificates attached to a result do not take part in the
-- comparison.
removeDuplicateMapOutput :: TopDownRuleOp (Wise SOACS)
removeDuplicateMapOutput _ (Pat pes) aux (Screma w arrs form)
  | Just fun <- isMapSOAC form =
      let ses = bodyResult $ lambdaBody fun
          ts = lambdaReturnType fun
          ses_ts_pes = zip3 ses ts pes
          -- A strict left fold: the accumulator is taken apart at every
          -- step anyway, so the lazy 'foldl' would only pile up thunks.
          (ses_ts_pes', copies) =
            foldl' checkForDuplicates (mempty, mempty) ses_ts_pes
       in if null copies
            then Skip
            else Simplify $ do
              let (ses', ts', pes') = unzip3 ses_ts_pes'
                  fun' =
                    fun
                      { lambdaBody = (lambdaBody fun) {bodyResult = ses'},
                        lambdaReturnType = ts'
                      }
              auxing aux $ do
                -- The map itself, now producing only the distinct results.
                letBind (Pat pes') $ Op $ Screma w arrs $ mapSOAC fun'
                -- Re-create each removed output from the one it duplicated.
                forM_ copies $ \(from, to) ->
                  letBind (Pat [to]) $ BasicOp $ Replicate mempty $ Var $ patElemName from
  where
    -- Accumulator: the results kept so far (in their original order)
    -- and, for each dropped result, the pair of (pattern element of the
    -- kept occurrence, pattern element of the dropped occurrence).
    checkForDuplicates (ses_ts_pes', copies) (se, t, pe)
      | Just (_, _, pe') <- find (\(x, _, _) -> resSubExp x == resSubExp se) ses_ts_pes' =
          -- This result has been returned before, producing the
          -- array pe'.
          (ses_ts_pes', (pe', pe) : copies)
      | otherwise = (ses_ts_pes' ++ [(se, t, pe)], copies)
removeDuplicateMapOutput _ _ _ _ = Skip

-- | Build the new shape used when a reshape inside a map of width @w@
-- is lifted out of the map.  The result is a coercion to the
-- one-dimensional shape @[w]@ combined (via '<>') with
-- 'newshapeInner' applied to that same shape and the per-element
-- @new_shape@.
--
-- NOTE(review): presumably this leaves the outermost (mapped)
-- dimension untouched and applies @new_shape@ only to the inner
-- dimensions -- confirm against 'newshapeInner'.
reshapeInner :: SubExp -> NewShape SubExp -> NewShape SubExp
reshapeInner w new_shape =
  let map_dim = Shape [w]
   in reshapeCoerce map_dim <> newshapeInner map_dim new_shape

-- | Mapping some operations becomes an extension of that operation.
--
-- When 'isMapWithOp' recognises the SOAC as a map whose body is a
-- single statement, and that statement is one of the operations
-- below applied to the lambda parameter(s), the whole map is replaced
-- by the operation applied directly to the map's input array(s):
--
-- * @Reshape@ of the single parameter becomes a @Reshape@ of the
--   input array using 'reshapeInner'.
--
-- * @Concat@ along dimension @d@ of exactly the lambda parameters (in
--   order) becomes a @Concat@ along dimension @d + 1@ of the input
--   arrays.
--
-- * @Rearrange@ of the single parameter becomes a @Rearrange@ of the
--   input array where dimension 0 stays in place and every index of
--   the original permutation is shifted by one.
--
-- The @Reshape@ and @Rearrange@ cases only fire when the usage table
-- does not report the map result as consumed.
-- NOTE(review): presumably because the replacement aliases the input
-- array rather than producing a fresh one -- confirm.
--
-- The replacement statement is certified with the certificates of the
-- original statement together with those of the inner statement.
mapOpToOp :: BottomUpRuleOp (Wise SOACS)
mapOpToOp (_, used) pat aux1 e =
  case isMapWithOp pat e of
    Just (map_pe, cs, w, BasicOp (Reshape reshape_arr newshape), [p], [arr])
      | paramName p == reshape_arr,
        notConsumed map_pe ->
          rewriteTo cs $ Reshape arr (reshapeInner w newshape)
    Just (_, cs, _, BasicOp (Concat d (arr :| arrs) dw), ps, outer_arr : outer_arrs)
      | (arr : arrs) == map paramName ps ->
          rewriteTo cs $ Concat (d + 1) (outer_arr :| outer_arrs) dw
    Just (map_pe, cs, _, BasicOp (Rearrange rearrange_arr perm), [p], [arr])
      | paramName p == rearrange_arr,
        notConsumed map_pe ->
          rewriteTo cs $ Rearrange arr (0 : map (+ 1) perm)
    -- Either not a single-operation map, or one of the side conditions
    -- above did not hold.
    _ -> Skip
  where
    -- True when the usage table does not mark the pattern element as
    -- consumed.
    notConsumed pe = not $ UT.isConsumed (patElemName pe) used

    -- Bind the original pattern to the given basic operation, carrying
    -- over the certificates of both the outer and the inner statement.
    rewriteTo cs op =
      Simplify $
        certifying (stmAuxCerts aux1 <> cs) $
          letBind pat $
            BasicOp op

-- | Recognise a map whose lambda body consists of exactly one
-- statement that produces the lambda's only result.
--
-- Succeeds when all of the following hold:
--
-- * the pattern binds exactly one element;
--
-- * the SOAC is a @Screma@ that 'isMapSOAC' accepts;
--
-- * the lambda body has exactly one statement, binding exactly one
--   name;
--
-- * the body has exactly one result, which is that name.
--
-- On success the tuple contains, in order: the pattern element of the
-- map, the certificates of the inner statement, the map width, the
-- expression of the inner statement, the lambda parameters, and the
-- input arrays of the map.
isMapWithOp ::
  Pat dec ->
  SOAC (Wise SOACS) ->
  Maybe
    ( PatElem dec,
      Certs,
      SubExp,
      Exp (Wise SOACS),
      [Param Type],
      [VName]
    )
isMapWithOp (Pat [map_pe]) (Screma w arrs form) = do
  map_lam <- isMapSOAC form
  let body = lambdaBody map_lam
  -- A failed pattern match in the 'Maybe' monad yields 'Nothing'.
  [Let (Pat [pe]) aux2 e'] <- pure $ stmsToList $ bodyStms body
  [SubExpRes _ (Var r)] <- pure $ bodyResult body
  guard $ r == patElemName pe
  pure (map_pe, stmAuxCerts aux2, w, e', lambdaParams map_lam, arrs)
isMapWithOp _ _ = Nothing

-- | Remove dead results from a reduction or scan (really: a redomap
-- or scanomap with a single operator).  A result whose name is unused
-- afterwards is not necessarily dead: the operator may still use it
-- to compute one of the live results, so the data dependencies of the
-- operator body are consulted before anything is removed.
removeDeadReduction :: BottomUpRuleOp (Wise SOACS)
removeDeadReduction (_, used) pat aux (Screma w arrs form)
  | Just ([Reduce comm redlam rednes], maplam) <- isRedomapSOAC form =
      removeDeadReduction' redlam rednes maplam $ \lam nes' ->
        redomapSOAC [Reduce comm lam nes']
  | Just ([Scan scanlam scannes], maplam) <- isScanomapSOAC form =
      removeDeadReduction' scanlam scannes maplam $ \lam nes' ->
        scanomapSOAC [Scan lam nes']
  where
    isUsed = (`UT.used` used)

    -- Shared implementation for the reduce and scan cases.  The last
    -- argument rebuilds the Screma form from the pruned operator
    -- lambda, the surviving neutral elements, and the pruned map
    -- lambda.
    removeDeadReduction' ::
      Lambda (Wise SOACS) ->
      [SubExp] ->
      Lambda (Wise SOACS) ->
      (Lambda (Wise SOACS) -> [SubExp] -> Lambda (Wise SOACS) -> ScremaForm (Wise SOACS)) ->
      Rule (Wise SOACS)
    removeDeadReduction' redlam nes maplam mkOp
      | any (not . isUsed) (patNames pat), -- Quick/cheap check
        not (and alive_mask) = Simplify $ do
          let -- Parameters of dead results are fixed to the
              -- corresponding neutral element; live ones are left alone.
              dead_fix =
                [ if alive then Nothing else Just ne
                  | (alive, ne) <- zip alive_mask nes
                ]
              (used_red_pes, used_nes) =
                unzip [(pe, ne) | (True, pe, ne) <- zip3 alive_mask red_pes nes]

          -- Bail out if no neutral element would actually be dropped.
          when (used_nes == nes) cannotSimplify

          -- The operator takes two parameters per result, hence the
          -- fix list is applied twice.
          fixed_redlam <- fixLambdaParams redlam (dead_fix ++ dead_fix)
          let redlam' = removeLambdaResults alive_mask fixed_redlam
              maplam' = removeLambdaResults alive_mask maplam

          auxing aux $
            letBind (Pat $ used_red_pes ++ map_pes) $
              Op $
                Screma w arrs $
                  mkOp redlam' used_nes maplam'
      | otherwise = Skip
      where
        num_nes = length nes
        -- The first @length nes@ pattern elements are the operator
        -- results; the rest come from the map part.
        (red_pes, map_pes) = splitAt num_nes $ patElems pat
        redlam_body = lambdaBody redlam
        redlam_deps = dataDependencies redlam_body
        redlam_res = bodyResult redlam_body
        redlam_params = lambdaParams redlam
        (redlam_xparams, redlam_yparams) = splitAt num_nes redlam_params
        -- Operator parameters (both halves) whose corresponding
        -- result is used after this statement.
        used_after =
          [ p
            | (pe, p) <- zip (red_pes <> red_pes) redlam_params,
              isUsed (patElemName pe)
          ]
        -- NOTE(review): presumably the set of names the used results
        -- transitively depend on -- confirm against
        -- 'findNecessaryForReturned'.
        necessary =
          findNecessaryForReturned
            (`elem` used_after)
            (zip redlam_params $ map resSubExp $ redlam_res <> redlam_res)
            redlam_deps
        isNecessary = (`nameIn` necessary) . paramName
        -- A result is alive if either of its two operator parameters
        -- is necessary.
        alive_mask =
          zipWith
            (\x y -> isNecessary x || isNecessary y)
            redlam_xparams
            redlam_yparams
removeDeadReduction _ _ _ _ = Skip

-- | Rewrite a map over concatenated inputs into a map over the
-- unconcatenated pieces.
--
-- The rule fires on a 'Screma' that is a pure map whose lambda returns
-- only accumulators, and where every input is either an accumulator
-- parameter or the result of a @Concat@ along dimension 0 of arrays that
-- all have the same outer size (optionally seen through a coercing
-- @Reshape@).  All concatenations must agree on that per-piece size @w'@.
--
-- The replacement is a single map of width @w'@ whose inputs are the
-- accumulators followed by the individual pieces, and whose body runs one
-- freshly renamed copy of the original lambda per piece, feeding the
-- accumulators produced by one copy into the next.
{-# NOINLINE fuseConcatScatter #-}
fuseConcatScatter :: TopDownRuleOp (Wise SOACS)
fuseConcatScatter vtable pat aux (Screma _ arrs form)
  | Just lam <- isMapSOAC form,
    all isAcc $ lambdaReturnType lam,
    -- Classify every (parameter, input) pair; any input that is neither
    -- an accumulator nor a suitable concatenation aborts the rule.  The
    -- @w' : _@ pattern additionally demands at least one concatenation.
    Just (accs, (ws@(w' : _), ps, xss, css)) <-
      second unzip4 . partitionEithers
        <$> mapM isConcatOrAcc (zip (lambdaParams lam) arrs),
    -- @xss@ is indexed by parameter, then by piece; @xivs@ is indexed by
    -- piece, then by parameter.
    xivs <- transpose xss,
    all (w' ==) ws = Simplify $ auxing aux $ certifying (mconcat css) $ do
      -- r is the amount of arrays being concatenated.
      let r = length xivs
          num_accs = length accs
      -- One renamed copy of the lambda per piece, with the accumulator
      -- parameters moved to the front.
      lams <- replicateM r (renameLambda lam {lambdaParams = map fst accs <> ps})
      let acc_params = map fst accs
          input_params = concatMap (drop num_accs . lambdaParams) lams
      lam' <-
        mkLambda (acc_params <> input_params) $
          subExpsRes <$> recurse (map (Var . paramName) acc_params) lams
      letBind pat $ Op $ Screma w' (map snd accs <> concat xivs) (mapSOAC lam')
  where
    -- Inline the bodies of the given lambdas one after another.  The
    -- accumulator parameters of each lambda are bound to the current
    -- accumulator values, and the results of its body become the
    -- accumulator values for the next lambda.
    recurse accs [] = pure accs
    recurse accs (lam : lams) = do
      -- We know that the accumulators are the first params and that the rest are
      -- already bound.
      forM_ (zip accs (lambdaParams lam)) $ \(acc, p) ->
        letBindNames [paramName p] $ BasicOp $ SubExp acc
      accs' <- map resSubExp <$> bodyBind (lambdaBody lam)
      recurse accs' lams

    -- Outer dimension of a variable known to the symbol table.
    sizeOf :: VName -> Maybe SubExp
    sizeOf x = arraySize 0 . typeOf <$> ST.lookup x vtable

    -- Classify one (lambda parameter, input array) pair:
    --
    --   * @Left (p, v)@ when the parameter has accumulator type;
    --
    --   * @Right (w, p, pieces, cs)@ when the input is bound to a
    --     concatenation along dimension 0 of @pieces@, all of outer size
    --     @w@, with @cs@ the certificates attached to that binding;
    --
    --   * @Nothing@ otherwise.
    isConcatOrAcc (p@(Param _ _ Acc {}), v) =
      pure (Left (p, v))
    isConcatOrAcc (p, v) = case ST.lookupExp v vtable of
      Just (BasicOp (Concat 0 (x :| ys) _), cs) -> do
        x_w <- sizeOf x
        y_ws <- mapM sizeOf ys
        guard $ all (x_w ==) y_ws
        pure (Right (x_w, p, x : ys, cs))
      Just (BasicOp (Reshape arr newshape), cs)
        -- Look through a coercing reshape, accumulating its certificates.
        | ReshapeCoerce <- reshapeKind newshape -> do
            Right (a, _, b, cs') <- isConcatOrAcc (p, arr)
            pure (Right (a, p, b, cs <> cs'))
      _ -> Nothing
fuseConcatScatter _ _ _ _ = Skip

-- | Replace reductions whose result is known without running them.
--
-- Two cases are handled:
--
--   * A redomap whose width is a constant zero: each result is bound
--     directly to the corresponding neutral element.
--
--   * A 'Screma' that is a single plain reduction: the decision is
--     delegated to 'foldClosedForm', which is given a lookup into the
--     symbol table, the pattern, the reduction operator, its neutral
--     elements and the input arrays.
--
-- Anything else is left alone.
simplifyClosedFormReduce :: TopDownRuleOp (Wise SOACS)
simplifyClosedFormReduce _ pat _ (Screma (Constant w) _ form)
  | Just nes <- concatMap redNeutral . fst <$> isRedomapSOAC form,
    zeroIsh w =
      -- NOTE(review): 'zip' stops at the shorter list, so only as many
      -- pattern names as there are neutral elements get bound here.  If
      -- 'isRedomapSOAC' can match forms that also produce map results,
      -- those names would be left unbound by this rewrite -- TODO confirm.
      Simplify . forM_ (zip (patNames pat) nes) $ \(v, ne) ->
        letBindNames [v] $ BasicOp $ SubExp ne
simplifyClosedFormReduce vtable pat _ (Screma _ arrs form)
  | Just [Reduce _ red_fun nes] <- isReduceSOAC form =
      Simplify $ foldClosedForm (`ST.lookupExp` vtable) pat red_fun nes arrs
simplifyClosedFormReduce _ _ _ _ = Skip

-- For now we just remove singleton SOACs and those with unroll attributes.
simplifyKnownIterationSOAC ::
  (Buildable rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
simplifyKnownIterationSOAC :: forall rep.
(Buildable rep, BuilderOps rep, HasSOAC rep) =>
TopDownRuleOp rep
simplifyKnownIterationSOAC TopDown rep
_ Pat (LetDec rep)
pat StmAux (ExpDec rep)
_ OpC rep rep
op
  | Just (Screma (Constant PrimValue
k) [VName]
arrs (ScremaForm Lambda rep
map_lam [Scan rep]
scans [Reduce rep]
reds)) <- OpC rep rep -> Maybe (SOAC rep)
forall rep. HasSOAC rep => Op rep -> Maybe (SOAC rep)
asSOAC OpC rep rep
op,
    PrimValue -> Bool
oneIsh PrimValue
k = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
      let (Reduce Commutativity
_ Lambda rep
red_lam [SubExp]
red_nes) = [Reduce rep] -> Reduce rep
forall rep. Buildable rep => [Reduce rep] -> Reduce rep
singleReduce [Reduce rep]
reds
          (Scan Lambda rep
scan_lam [SubExp]
scan_nes) = [Scan rep] -> Scan rep
forall rep. Buildable rep => [Scan rep] -> Scan rep
singleScan [Scan rep]
scans
          ([PatElem (LetDec rep)]
scan_pes, [PatElem (LetDec rep)]
red_pes, [PatElem (LetDec rep)]
map_pes) =
            Int
-> Int
-> [PatElem (LetDec rep)]
-> ([PatElem (LetDec rep)], [PatElem (LetDec rep)],
    [PatElem (LetDec rep)])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes) ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) ([PatElem (LetDec rep)]
 -> ([PatElem (LetDec rep)], [PatElem (LetDec rep)],
     [PatElem (LetDec rep)]))
-> [PatElem (LetDec rep)]
-> ([PatElem (LetDec rep)], [PatElem (LetDec rep)],
    [PatElem (LetDec rep)])
forall a b. (a -> b) -> a -> b
$
              Pat (LetDec rep) -> [PatElem (LetDec rep)]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec rep)
pat
          bindMapParam :: Param dec -> VName -> m ()
bindMapParam Param dec
p VName
a = do
            TypeBase Shape NoUniqueness
a_t <- VName -> m (TypeBase Shape NoUniqueness)
forall rep (m :: * -> *).
HasScope rep m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
a
            [VName] -> Exp (Rep m) -> m ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
p] (Exp (Rep m) -> m ())
-> (BasicOp -> Exp (Rep m)) -> BasicOp -> m ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> m ()) -> BasicOp -> m ()
forall a b. (a -> b) -> a -> b
$
              if TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
isAcc TypeBase Shape NoUniqueness
a_t
                then SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
a
                else VName -> Slice SubExp -> BasicOp
Index VName
a (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ TypeBase Shape NoUniqueness -> [DimIndex SubExp] -> Slice SubExp
fullSlice TypeBase Shape NoUniqueness
a_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64)]
          bindArrayResult :: PatElem dec -> SubExpRes -> m ()
bindArrayResult PatElem dec
pe (SubExpRes Certs
cs SubExp
se) =
            Certs -> m () -> m ()
forall a. Certs -> m a -> m a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (m () -> m ()) -> (Exp (Rep m) -> m ()) -> Exp (Rep m) -> m ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [VName] -> Exp (Rep m) -> m ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem dec -> VName
forall dec. PatElem dec -> VName
patElemName PatElem dec
pe] (Exp (Rep m) -> m ()) -> Exp (Rep m) -> m ()
forall a b. (a -> b) -> a -> b
$
              BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$
                [SubExp] -> TypeBase Shape NoUniqueness -> BasicOp
ArrayLit [SubExp
se] (TypeBase Shape NoUniqueness -> BasicOp)
-> TypeBase Shape NoUniqueness -> BasicOp
forall a b. (a -> b) -> a -> b
$
                  TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
forall u. TypeBase Shape u -> TypeBase Shape u
rowType (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
forall a b. (a -> b) -> a -> b
$
                    PatElem dec -> TypeBase Shape NoUniqueness
forall dec. Typed dec => PatElem dec -> TypeBase Shape NoUniqueness
patElemType PatElem dec
pe
          bindResult :: PatElem dec -> SubExpRes -> m ()
bindResult PatElem dec
pe (SubExpRes Certs
cs SubExp
se) =
            Certs -> m () -> m ()
forall a. Certs -> m a -> m a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (m () -> m ()) -> m () -> m ()
forall a b. (a -> b) -> a -> b
$ [VName] -> Exp (Rep m) -> m ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem dec -> VName
forall dec. PatElem dec -> VName
patElemName PatElem dec
pe] (Exp (Rep m) -> m ()) -> Exp (Rep m) -> m ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se

      (Param (TypeBase Shape NoUniqueness) -> VName -> RuleM rep ())
-> [Param (TypeBase Shape NoUniqueness)] -> [VName] -> RuleM rep ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param (TypeBase Shape NoUniqueness) -> VName -> RuleM rep ()
forall {m :: * -> *} {dec}.
MonadBuilder m =>
Param dec -> VName -> m ()
bindMapParam (Lambda rep -> [LParam rep]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
map_lam) [VName]
arrs
      ([SubExpRes]
to_scan, [SubExpRes]
to_red, [SubExpRes]
map_res) <-
        Int
-> Int -> [SubExpRes] -> ([SubExpRes], [SubExpRes], [SubExpRes])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes) ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes)
          ([SubExpRes] -> ([SubExpRes], [SubExpRes], [SubExpRes]))
-> RuleM rep [SubExpRes]
-> RuleM rep ([SubExpRes], [SubExpRes], [SubExpRes])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Body (Rep (RuleM rep)) -> RuleM rep [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Body (Rep m) -> m [SubExpRes]
bodyBind (Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
map_lam)
      [SubExpRes]
scan_res <- Lambda (Rep (RuleM rep))
-> [RuleM rep (Exp (Rep (RuleM rep)))] -> RuleM rep [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda rep
Lambda (Rep (RuleM rep))
scan_lam ([RuleM rep (Exp (Rep (RuleM rep)))] -> RuleM rep [SubExpRes])
-> [RuleM rep (Exp (Rep (RuleM rep)))] -> RuleM rep [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (SubExp -> RuleM rep (Exp (Rep (RuleM rep))))
-> [SubExp] -> [RuleM rep (Exp (Rep (RuleM rep)))]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> RuleM rep (Exp (Rep (RuleM rep)))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp] -> [RuleM rep (Exp (Rep (RuleM rep)))])
-> [SubExp] -> [RuleM rep (Exp (Rep (RuleM rep)))]
forall a b. (a -> b) -> a -> b
$ [SubExp]
scan_nes [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp [SubExpRes]
to_scan
      [SubExpRes]
red_res <- Lambda (Rep (RuleM rep))
-> [RuleM rep (Exp (Rep (RuleM rep)))] -> RuleM rep [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda rep
Lambda (Rep (RuleM rep))
red_lam ([RuleM rep (Exp (Rep (RuleM rep)))] -> RuleM rep [SubExpRes])
-> [RuleM rep (Exp (Rep (RuleM rep)))] -> RuleM rep [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (SubExp -> RuleM rep (Exp (Rep (RuleM rep))))
-> [SubExp] -> [RuleM rep (Exp (Rep (RuleM rep)))]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> RuleM rep (Exp (Rep (RuleM rep)))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp] -> [RuleM rep (Exp (Rep (RuleM rep)))])
-> [SubExp] -> [RuleM rep (Exp (Rep (RuleM rep)))]
forall a b. (a -> b) -> a -> b
$ [SubExp]
red_nes [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp [SubExpRes]
to_red

      (PatElem (LetDec rep) -> SubExpRes -> RuleM rep ())
-> [PatElem (LetDec rep)] -> [SubExpRes] -> RuleM rep ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ PatElem (LetDec rep) -> SubExpRes -> RuleM rep ()
forall {m :: * -> *} {dec}.
(MonadBuilder m, Typed dec) =>
PatElem dec -> SubExpRes -> m ()
bindArrayResult [PatElem (LetDec rep)]
scan_pes [SubExpRes]
scan_res
      (PatElem (LetDec rep) -> SubExpRes -> RuleM rep ())
-> [PatElem (LetDec rep)] -> [SubExpRes] -> RuleM rep ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ PatElem (LetDec rep) -> SubExpRes -> RuleM rep ()
forall {m :: * -> *} {dec}.
MonadBuilder m =>
PatElem dec -> SubExpRes -> m ()
bindResult [PatElem (LetDec rep)]
red_pes [SubExpRes]
red_res
      (PatElem (LetDec rep) -> SubExpRes -> RuleM rep ())
-> [PatElem (LetDec rep)] -> [SubExpRes] -> RuleM rep ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ PatElem (LetDec rep) -> SubExpRes -> RuleM rep ()
forall {m :: * -> *} {dec}.
(MonadBuilder m, Typed dec) =>
PatElem dec -> SubExpRes -> m ()
bindArrayResult [PatElem (LetDec rep)]
map_pes [SubExpRes]
map_res
-- A 'Stream' whose width is a constant that 'oneIsh' accepts is
-- replaced by a single inlined application of its fold lambda:
--
--   * the chunk-size parameter is bound to the constant 1,
--   * each accumulator parameter is bound to its initial value,
--   * each slice parameter is bound directly to the matching input
--     array (no indexing is inserted),
--   * the lambda body is inlined, and each pattern name is bound to
--     the corresponding result under that result's certificates.
simplifyKnownIterationSOAC _ pat _ op
  | Just (Stream (Constant k) arrs nes fold_lam) <- asSOAC op,
    oneIsh k = Simplify $ do
      let (chunk_param, acc_params, slice_params) =
            partitionChunkedFoldParameters (length nes) (lambdaParams fold_lam)
          -- Bind a single name to a plain subexpression.
          bindTo name se = letBindNames [name] $ BasicOp $ SubExp se

      bindTo (paramName chunk_param) $ intConst Int64 1

      zipWithM_ (bindTo . paramName) acc_params nes

      zipWithM_ (\p arr -> bindTo (paramName p) (Var arr)) slice_params arrs

      res <- bodyBind $ lambdaBody fold_lam

      zipWithM_
        (\v (SubExpRes cs se) -> certifying cs $ bindTo v se)
        (patNames pat)
        res
--
-- A pure map ('Screma' with no scans or reductions) of constant width
-- @k@ that carries the "unroll" attribute is fully unrolled: the
-- (freshly renamed) lambda is applied once per index @0 .. k-1@ to the
-- indexed input arrays, and every output is rebuilt as an 'ArrayLit'
-- of the per-iteration results, certified by the union of those
-- results' certificates.
--
-- When @k <= 0@ there are no iterations.  Transposing the empty list
-- of per-iteration results would give an empty list of output columns,
-- so zipping it with the pattern names would bind nothing and leave
-- the outputs of the replaced statement undefined.  In that case each
-- output is instead bound to an empty array literal of the lambda's
-- return type.
simplifyKnownIterationSOAC _ pat aux op
  | Just (Screma (Constant (IntValue (Int64Value k))) arrs (ScremaForm map_lam [] [])) <- asSOAC op,
    "unroll" `inAttrs` stmAuxAttrs aux = Simplify $ do
      -- One list of results per iteration.
      per_iter <- forM [0 .. k - 1] $ \i -> do
        -- Rename so that each unrolled copy binds distinct names.
        map_lam' <- renameLambda map_lam
        eLambda map_lam' $ map (`eIndex` [eSubExp (constant i)]) arrs
      let ret_ts = lambdaReturnType map_lam
          -- One list of elements per output array.
          arrs_elems
            | null per_iter = map (const []) ret_ts
            | otherwise = transpose per_iter
      forM_ (zip3 (patNames pat) arrs_elems ret_ts) $
        \(v, arr_elems, t) ->
          certifying (mconcat (map resCerts arr_elems)) $
            letBindNames [v] . BasicOp $
              ArrayLit (map resSubExp arr_elems) t
--
-- None of the preceding cases matched: the rule does not fire.
simplifyKnownIterationSOAC _ _ _ _ =
  Skip

-- | A simple array transformation applied to a named array, together
-- with the certificates it is performed under.  'isArrayOp'
-- recognises these in expressions and 'fromArrayOp' turns them back
-- into expressions.
data ArrayOp
  = -- | Indexing/slicing of the array with the given slice.
    ArrayIndexing Certs VName (Slice SubExp)
  | -- | Rearranging the dimensions of the array by the given
    -- permutation.
    ArrayRearrange Certs VName [Int]
  | -- | Reshaping the array to the given new shape.
    ArrayReshape Certs VName (NewShape SubExp)
  | -- | A copy of the array (a 'Replicate' with an empty shape).
    ArrayCopy Certs VName
  | -- | Never constructed.  ('isArrayOp' has no case that produces it;
    -- 'fromArrayOp' maps it to a plain reference to the variable.)
    ArrayVar Certs VName
  deriving (Eq, Ord, Show)

-- | The name of the array that the operation is applied to.
arrayOpArr :: ArrayOp -> VName
arrayOpArr op = case op of
  ArrayIndexing _ arr _ -> arr
  ArrayRearrange _ arr _ -> arr
  ArrayReshape _ arr _ -> arr
  ArrayCopy _ arr -> arr
  ArrayVar _ arr -> arr

-- | The certificates stored with the operation.
arrayOpCerts :: ArrayOp -> Certs
arrayOpCerts op = case op of
  ArrayIndexing cs _ _ -> cs
  ArrayRearrange cs _ _ -> cs
  ArrayReshape cs _ _ -> cs
  ArrayCopy cs _ -> cs
  ArrayVar cs _ -> cs

-- | Recognise an expression as an 'ArrayOp', attaching the given
-- certificates.  Only 'Index', 'Rearrange', 'Reshape', and a
-- 'Replicate' with an empty shape of a variable (treated as a copy)
-- are recognised; anything else yields 'Nothing'.
isArrayOp :: Certs -> Exp rep -> Maybe ArrayOp
isArrayOp cs (BasicOp bop) = case bop of
  Index arr slice ->
    Just $ ArrayIndexing cs arr slice
  Rearrange arr perm ->
    Just $ ArrayRearrange cs arr perm
  Reshape arr new_shape ->
    Just $ ArrayReshape cs arr new_shape
  Replicate (Shape []) (Var arr) ->
    Just $ ArrayCopy cs arr
  _ ->
    Nothing
isArrayOp _ _ =
  Nothing

-- | Convert an 'ArrayOp' back into the certificates it carries and the
-- expression that performs it.  An 'ArrayCopy' becomes a 'Replicate'
-- with an empty shape, and an 'ArrayVar' becomes a plain reference to
-- the variable.
fromArrayOp :: ArrayOp -> (Certs, Exp rep)
fromArrayOp op = case op of
  ArrayIndexing cs arr slice ->
    basic cs $ Index arr slice
  ArrayRearrange cs arr perm ->
    basic cs $ Rearrange arr perm
  ArrayReshape cs arr new_shape ->
    basic cs $ Reshape arr new_shape
  ArrayCopy cs arr ->
    basic cs $ Replicate mempty $ Var arr
  ArrayVar cs arr ->
    basic cs $ SubExp $ Var arr
  where
    -- Pair the certificates with the operation wrapped as an expression.
    basic cs bop = (cs, BasicOp bop)

-- | Collect the array operations (as recognised by 'isArrayOp') bound
-- by the statements of a body, each paired with the pattern it is
-- bound to.  The given certificates are added to those of every
-- operation found.
--
-- A statement that is itself an array operation contributes exactly
-- that operation.  Any other statement is searched recursively through
-- its nested bodies and SOAC lambdas, except that 'Match' and 'Loop'
-- statements are not searched at all, and copies found inside a nested
-- SOAC are discarded.
arrayOps ::
  forall rep.
  (Buildable rep, HasSOAC rep) =>
  Certs ->
  Body rep ->
  S.Set (Pat (LetDec rep), ArrayOp)
arrayOps cs = foldMap onStm . stmsToList . bodyStms
  where
    onStm (Let pat aux e) = case e of
      -- It is not safe to move everything out of branches (#1874) or
      -- loops (#2015); probably we need to put some more intelligence
      -- in here somehow.
      Match {} -> mempty
      Loop {} -> mempty
      _
        | Just op <- isArrayOp (cs <> stm_cs) e ->
            S.singleton (pat, op)
        | otherwise ->
            execState (walkExpM (walker stm_cs) e) mempty
      where
        stm_cs = stmAuxCerts aux

    -- Walk the nested bodies and ops of an expression, prepending
    -- whatever is found to the accumulated state.
    walker more_cs =
      (identityWalker @rep)
        { walkOnBody = \_ body ->
            modify (arrayOps (cs <> more_cs) body <>),
          walkOnOp = \op ->
            modify (onOp more_cs op <>)
        }

    onOp more_cs op = case asSOAC op of
      Just soac ->
        -- Copies are not safe to move out of nested ops (#1753).
        S.filter (notCopy . snd) . execWriter $
          mapSOACM
            identitySOACMapper {mapOnSOACLambda = onLambda more_cs}
            (soac :: SOAC rep)
      Nothing ->
        mempty

    -- Record the operations in the lambda body; the lambda itself is
    -- returned unchanged.
    onLambda more_cs lam =
      lam <$ tell (arrayOps (cs <> more_cs) (lambdaBody lam))

    notCopy op = case op of
      ArrayCopy {} -> False
      _ -> True

-- | Rewrite a body according to a substitution map from patterns to
-- array operations.  A statement whose pattern is a key of the map has
-- its expression replaced by the mapped operation (via 'fromArrayOp')
-- and is additionally certified by that operation's certificates.
-- Every other statement keeps its expression, except that the
-- replacement is applied recursively to its nested bodies and to the
-- lambdas of any SOAC it contains.  The body result is unchanged.
replaceArrayOps ::
  forall rep.
  (Buildable rep, BuilderOps rep, HasSOAC rep) =>
  M.Map (Pat (LetDec rep)) ArrayOp ->
  Body rep ->
  Body rep
replaceArrayOps substs (Body _ stms res) =
  mkBody (onStm <$> stms) res
  where
    recurse = replaceArrayOps substs

    onStm (Let pat aux e) =
      certify new_cs $ mkLet' (patIdents pat) aux new_e
      where
        (new_cs, new_e) = case M.lookup pat substs of
          Just op -> fromArrayOp op
          Nothing -> (mempty, mapExp mapper e)

    mapper =
      (identityMapper @rep)
        { mapOnBody = \_ body -> pure $ recurse body,
          mapOnOp = pure . onOp
        }

    onOp op = case asSOAC op of
      Just (soac :: SOAC rep) ->
        soacOp . runIdentity $
          mapSOACM
            identitySOACMapper {mapOnSOACLambda = pure . onLambda}
            soac
      Nothing ->
        op

    onLambda lam = lam {lambdaBody = recurse $ lambdaBody lam}

-- Turn
--
--    map (\i -> ... xs[i] ...) (iota n)
--
-- into
--
--    map (\i x -> ... x ...) (iota n) xs
--
-- This is not because we want to encourage the map-iota pattern, but
-- it may be present in generated code.  This is an unfortunately
-- expensive simplification rule, since it requires multiple passes
-- over the entire lambda body.  It only handles the very simplest
-- case - if you find yourself planning to extend it to handle more
-- complex situations (rotate or whatnot), consider turning it into a
-- separate compiler pass instead.
simplifyMapIota ::
  forall rep.
  (Buildable rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
simplifyMapIota vtable screma_pat aux op
  -- The operation must be a Screma...
  | Just (Screma w arrs (ScremaForm map_lam scan reduce) :: SOAC rep) <- asSOAC op,
    -- ...one of whose inputs is bound to an iota with constant offset
    -- and stride accepted by 'zeroIsh' and 'oneIsh' respectively; 'p'
    -- is the lambda parameter paired with that input.  Only the first
    -- such input is considered.
    Just (p, _) <- find isIota (zip (lambdaParams map_lam) arrs),
    -- Collect the array indexings in the lambda body that use 'p' as
    -- an index into an array bound outside the Screma.
    indexings <-
      mapMaybe (indexesWith (paramName p)) . S.toList $
        arrayOps mempty $
          lambdaBody map_lam,
    not $ null indexings = Simplify $ do
      -- For each indexing with iota, add the corresponding array to
      -- the Screma, and construct a new lambda parameter.
      (more_arrs, more_params, replacements) <-
        unzip3 . catMaybes <$> mapM (mapOverArr w) indexings
      let substs = M.fromList replacements
          map_lam' =
            map_lam
              { lambdaParams = lambdaParams map_lam <> more_params,
                lambdaBody = replaceArrayOps substs $ lambdaBody map_lam
              }

      -- Rebind the original pattern to a Screma over the extended
      -- inputs, keeping the auxiliary information of the original
      -- statement.
      auxing aux . letBind screma_pat . Op . soacOp $
        Screma w (arrs <> more_arrs) (ScremaForm map_lam' scan reduce)
  where
    -- Is this input array known (via the symbol table) to be an iota
    -- whose offset and stride are constants satisfying 'zeroIsh' and
    -- 'oneIsh'?
    isIota (_, arr) = case ST.lookupBasicOp arr vtable of
      Just (Iota _ (Constant o) (Constant s) _, _) ->
        zeroIsh o && oneIsh s
      _ -> False

    -- Find a 'DimFix i', optionally preceded by other DimFixes, and
    -- if so return those DimFixes.
    fixWith i (DimFix j : slice)
      | Var i == j = Just []
      | otherwise = (j :) <$> fixWith i slice
    fixWith _ _ = Nothing

    -- Accept an indexing only when the array, all of its
    -- certificates, and every variable free in the leading fixed
    -- indices are present in the symbol table, and the slice fixes a
    -- dimension to 'v' after zero or more other fixed dimensions.
    -- The leading fixed indices are returned alongside the operation.
    indexesWith v (pat, idx@(ArrayIndexing cs arr (Slice js)))
      | arr `ST.elem` vtable,
        all (`ST.elem` vtable) $ unCerts cs,
        Just js' <- fixWith v js,
        all (`ST.elem` vtable) $ namesToList $ freeIn js' =
          Just (pat, js', idx)
    indexesWith _ _ = Nothing

    -- With no leading fixed indices the array is used as-is;
    -- otherwise bind a new array obtained by fixing the leading
    -- dimensions of 'arr' to 'js'.
    --
    -- NOTE(review): unlike the prefix slice in 'mapOverArr', this
    -- index is not wrapped in the certificates of the original
    -- indexing - confirm that this is intended.
    properArr [] arr = pure arr
    properArr js arr = do
      arr_t <- lookupType arr
      letExp (baseName arr) $ BasicOp $ Index arr $ fullSlice arr_t $ map DimFix js

    -- Produce (new Screma input, new lambda parameter, replacement
    -- for the original indexing).
    mapOverArr w (pat, js, ArrayIndexing cs arr slice) = do
      arr' <- properArr js arr
      arr_t <- lookupType arr'
      -- The new input must have outer size 'w'; if the outer size of
      -- the array is not syntactically 'w', take the slice [0:w:1] of
      -- its outer dimension, under the original certificates.
      arr'' <-
        if arraySize 0 arr_t == w
          then pure arr'
          else
            certifying cs . letExp (baseName arr <> "_prefix") . BasicOp . Index arr' $
              fullSlice arr_t [DimSlice (intConst Int64 0) w (intConst Int64 1)]
      arr_elem_param <- newParam (baseName arr <> "_elem") (rowType arr_t)
      pure $
        Just
          ( arr'',
            arr_elem_param,
            ( pat,
              -- The leading fixed indices and the iota index are now
              -- handled by the new input, so only the rest of the
              -- slice is applied to the new parameter.
              ArrayIndexing cs (paramName arr_elem_param) (Slice (drop (length js + 1) (unSlice slice)))
            )
          )
    mapOverArr _ _ = pure Nothing
simplifyMapIota _ _ _ _ = Skip

-- If a Screma's map function contains a transformation
-- (e.g. transpose) on a parameter, create a new parameter
-- corresponding to that transformation performed on the rows of the
-- full array.
moveTransformToInput :: TopDownRuleOp (Wise SOACS)
moveTransformToInput :: RuleOp (Wise SOACS) (SymbolTable (Wise SOACS))
moveTransformToInput SymbolTable (Wise SOACS)
vtable Pat (LetDec (Wise SOACS))
screma_pat StmAux (ExpDec (Wise SOACS))
aux soac :: Op (Wise SOACS)
soac@(Screma SubExp
w [VName]
arrs (ScremaForm Lambda (Wise SOACS)
map_lam [Scan (Wise SOACS)]
scan [Reduce (Wise SOACS)]
reduce))
  | [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)]
ops <- ((Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp) -> Bool)
-> [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)]
-> [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)]
forall a. (a -> Bool) -> [a] -> [a]
filter (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp) -> Bool
arrayIsMapParam ([(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)]
 -> [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)])
-> [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)]
-> [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)]
forall a b. (a -> b) -> a -> b
$ Set (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)
-> [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)]
forall a. Set a -> [a]
S.toList (Set (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)
 -> [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)])
-> Set (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)
-> [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)]
forall a b. (a -> b) -> a -> b
$ Certs
-> GBody (Wise SOACS) SubExpRes
-> Set (Pat (LetDec (Wise SOACS)), ArrayOp)
forall rep.
(Buildable rep, HasSOAC rep) =>
Certs -> Body rep -> Set (Pat (LetDec rep), ArrayOp)
arrayOps Certs
forall a. Monoid a => a
mempty (GBody (Wise SOACS) SubExpRes
 -> Set (Pat (LetDec (Wise SOACS)), ArrayOp))
-> GBody (Wise SOACS) SubExpRes
-> Set (Pat (LetDec (Wise SOACS)), ArrayOp)
forall a b. (a -> b) -> a -> b
$ Lambda (Wise SOACS) -> GBody (Wise SOACS) SubExpRes
forall rep. Lambda rep -> Body rep
lambdaBody Lambda (Wise SOACS)
map_lam,
    Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)]
ops = RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM (Wise SOACS) () -> Rule (Wise SOACS))
-> RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ do
      ([VName]
more_arrs, [Param (TypeBase Shape NoUniqueness)]
more_params, [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)]
replacements) <-
        [(VName, Param (TypeBase Shape NoUniqueness),
  (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))]
-> ([VName], [Param (TypeBase Shape NoUniqueness)],
    [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(VName, Param (TypeBase Shape NoUniqueness),
   (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))]
 -> ([VName], [Param (TypeBase Shape NoUniqueness)],
     [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)]))
-> ([Maybe
       (VName, Param (TypeBase Shape NoUniqueness),
        (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))]
    -> [(VName, Param (TypeBase Shape NoUniqueness),
         (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))])
-> [Maybe
      (VName, Param (TypeBase Shape NoUniqueness),
       (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))]
-> ([VName], [Param (TypeBase Shape NoUniqueness)],
    [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Maybe
   (VName, Param (TypeBase Shape NoUniqueness),
    (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))]
-> [(VName, Param (TypeBase Shape NoUniqueness),
     (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))]
forall a. [Maybe a] -> [a]
catMaybes ([Maybe
    (VName, Param (TypeBase Shape NoUniqueness),
     (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))]
 -> ([VName], [Param (TypeBase Shape NoUniqueness)],
     [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)]))
-> RuleM
     (Wise SOACS)
     [Maybe
        (VName, Param (TypeBase Shape NoUniqueness),
         (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))]
-> RuleM
     (Wise SOACS)
     ([VName], [Param (TypeBase Shape NoUniqueness)],
      [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)
 -> RuleM
      (Wise SOACS)
      (Maybe
         (VName, Param (TypeBase Shape NoUniqueness),
          (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))))
-> [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)]
-> RuleM
     (Wise SOACS)
     [Maybe
        (VName, Param (TypeBase Shape NoUniqueness),
         (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)
-> RuleM
     (Wise SOACS)
     (Maybe
        (VName, Param (TypeBase Shape NoUniqueness),
         (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)))
mapOverArr [(Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)]
ops

      Bool -> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([VName] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [VName]
more_arrs) RuleM (Wise SOACS) ()
forall rep a. RuleM rep a
cannotSimplify

      let map_lam' :: Lambda (Wise SOACS)
map_lam' =
            Lambda (Wise SOACS)
map_lam
              { lambdaParams = lambdaParams map_lam <> more_params,
                lambdaBody = replaceArrayOps (M.fromList replacements) $ lambdaBody map_lam
              }

      StmAux (ExpWisdom, ())
-> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpWisdom, ())
StmAux (ExpDec (Wise SOACS))
aux (RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ())
-> (SOAC (Wise SOACS) -> RuleM (Wise SOACS) ())
-> SOAC (Wise SOACS)
-> RuleM (Wise SOACS) ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pat (LetDec (Rep (RuleM (Wise SOACS))))
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec (Rep (RuleM (Wise SOACS))))
Pat (LetDec (Wise SOACS))
screma_pat (Exp (Wise SOACS) -> RuleM (Wise SOACS) ())
-> (SOAC (Wise SOACS) -> Exp (Wise SOACS))
-> SOAC (Wise SOACS)
-> RuleM (Wise SOACS) ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op (Wise SOACS) -> Exp (Wise SOACS)
SOAC (Wise SOACS) -> Exp (Wise SOACS)
forall rep. Op rep -> Exp rep
Op (SOAC (Wise SOACS) -> RuleM (Wise SOACS) ())
-> SOAC (Wise SOACS) -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$
        SubExp -> [VName] -> ScremaForm (Wise SOACS) -> SOAC (Wise SOACS)
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w ([VName]
arrs [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
more_arrs) (Lambda (Wise SOACS)
-> [Scan (Wise SOACS)]
-> [Reduce (Wise SOACS)]
-> ScremaForm (Wise SOACS)
forall rep.
Lambda rep -> [Scan rep] -> [Reduce rep] -> ScremaForm rep
ScremaForm Lambda (Wise SOACS)
map_lam' [Scan (Wise SOACS)]
scan [Reduce (Wise SOACS)]
reduce)
  where
    -- It is not safe to move the transform if the root array is being
    -- consumed by the Screma.  This is a bit too conservative - it's
    -- actually safe if we completely replace the original input, but
    -- this rule is not that precise.
    consumed :: Names
consumed = SOAC (Wise SOACS) -> Names
forall rep. Aliased rep => SOAC rep -> Names
forall (op :: * -> *) rep.
(AliasedOp op, Aliased rep) =>
op rep -> Names
consumedInOp Op (Wise SOACS)
SOAC (Wise SOACS)
soac
    map_param_names :: [VName]
map_param_names = (Param (TypeBase Shape NoUniqueness) -> VName)
-> [Param (TypeBase Shape NoUniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName (Lambda (Wise SOACS) -> [LParam (Wise SOACS)]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda (Wise SOACS)
map_lam)
    topLevelPat :: Pat (VarWisdom, TypeBase Shape NoUniqueness) -> Bool
topLevelPat = (Pat (VarWisdom, TypeBase Shape NoUniqueness)
-> Seq (Pat (VarWisdom, TypeBase Shape NoUniqueness)) -> Bool
forall a. Eq a => a -> Seq a -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` (Stm (Wise SOACS) -> Pat (VarWisdom, TypeBase Shape NoUniqueness))
-> Stms (Wise SOACS)
-> Seq (Pat (VarWisdom, TypeBase Shape NoUniqueness))
forall a b. (a -> b) -> Seq a -> Seq b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Stm (Wise SOACS) -> Pat (VarWisdom, TypeBase Shape NoUniqueness)
Stm (Wise SOACS) -> Pat (LetDec (Wise SOACS))
forall rep. Stm rep -> Pat (LetDec rep)
stmPat (GBody (Wise SOACS) SubExpRes -> Stms (Wise SOACS)
forall rep res. GBody rep res -> Stms rep
bodyStms (Lambda (Wise SOACS) -> GBody (Wise SOACS) SubExpRes
forall rep. Lambda rep -> Body rep
lambdaBody Lambda (Wise SOACS)
map_lam)))
    onlyUsedOnce :: VName -> Bool
onlyUsedOnce VName
arr =
      case (Stm (Wise SOACS) -> Bool)
-> [Stm (Wise SOACS)] -> [Stm (Wise SOACS)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName
arr `nameIn`) (Names -> Bool)
-> (Stm (Wise SOACS) -> Names) -> Stm (Wise SOACS) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm (Wise SOACS) -> Names
forall a. FreeIn a => a -> Names
freeIn) ([Stm (Wise SOACS)] -> [Stm (Wise SOACS)])
-> [Stm (Wise SOACS)] -> [Stm (Wise SOACS)]
forall a b. (a -> b) -> a -> b
$ Stms (Wise SOACS) -> [Stm (Wise SOACS)]
forall rep. Stms rep -> [Stm rep]
stmsToList (Stms (Wise SOACS) -> [Stm (Wise SOACS)])
-> Stms (Wise SOACS) -> [Stm (Wise SOACS)]
forall a b. (a -> b) -> a -> b
$ GBody (Wise SOACS) SubExpRes -> Stms (Wise SOACS)
forall rep res. GBody rep res -> Stms rep
bodyStms (GBody (Wise SOACS) SubExpRes -> Stms (Wise SOACS))
-> GBody (Wise SOACS) SubExpRes -> Stms (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ Lambda (Wise SOACS) -> GBody (Wise SOACS) SubExpRes
forall rep. Lambda rep -> Body rep
lambdaBody Lambda (Wise SOACS)
map_lam of
        Stm (Wise SOACS)
_ : Stm (Wise SOACS)
_ : [Stm (Wise SOACS)]
_ -> Bool
False
        [Stm (Wise SOACS)]
_ -> Bool
True

    -- It's not just about whether the array is a parameter;
    -- everything else must be map-invariant.
    arrayIsMapParam :: (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp) -> Bool
arrayIsMapParam (Pat (VarWisdom, TypeBase Shape NoUniqueness)
pat', ArrayIndexing Certs
cs VName
arr Slice SubExp
slice) =
      VName
arr VName -> [VName] -> Bool
forall a. Eq a => a -> [a] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Wise SOACS) -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Certs -> Names
forall a. FreeIn a => a -> Names
freeIn Certs
cs Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Slice SubExp -> Names
forall a. FreeIn a => a -> Names
freeIn Slice SubExp
slice)
        Bool -> Bool -> Bool
&& Bool -> Bool
not (Slice SubExp -> Bool
forall a. Slice a -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null Slice SubExp
slice)
        Bool -> Bool -> Bool
&& (Bool -> Bool
not ([SubExp] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null ([SubExp] -> Bool) -> [SubExp] -> Bool
forall a b. (a -> b) -> a -> b
$ Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims Slice SubExp
slice) Bool -> Bool -> Bool
|| (Pat (VarWisdom, TypeBase Shape NoUniqueness) -> Bool
topLevelPat Pat (VarWisdom, TypeBase Shape NoUniqueness)
pat' Bool -> Bool -> Bool
&& VName -> Bool
onlyUsedOnce VName
arr))
    arrayIsMapParam (Pat (VarWisdom, TypeBase Shape NoUniqueness)
_, ArrayRearrange Certs
cs VName
arr [Int]
perm) =
      VName
arr VName -> [VName] -> Bool
forall a. Eq a => a -> [a] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Wise SOACS) -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Certs -> Names
forall a. FreeIn a => a -> Names
freeIn Certs
cs)
        Bool -> Bool -> Bool
&& Bool -> Bool
not ([Int] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Int]
perm)
    arrayIsMapParam (Pat (VarWisdom, TypeBase Shape NoUniqueness)
_, ArrayReshape Certs
cs VName
arr NewShape SubExp
new_shape) =
      VName
arr VName -> [VName] -> Bool
forall a. Eq a => a -> [a] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Wise SOACS) -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Certs -> Names
forall a. FreeIn a => a -> Names
freeIn Certs
cs Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> NewShape SubExp -> Names
forall a. FreeIn a => a -> Names
freeIn NewShape SubExp
new_shape)
    arrayIsMapParam (Pat (VarWisdom, TypeBase Shape NoUniqueness)
_, ArrayCopy Certs
cs VName
arr) =
      VName
arr VName -> [VName] -> Bool
forall a. Eq a => a -> [a] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Wise SOACS) -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Certs -> Names
forall a. FreeIn a => a -> Names
freeIn Certs
cs)
    arrayIsMapParam (Pat (VarWisdom, TypeBase Shape NoUniqueness)
_, ArrayVar {}) =
      Bool
False

    mapOverArr :: (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)
-> RuleM
     (Wise SOACS)
     (Maybe
        (VName, Param (TypeBase Shape NoUniqueness),
         (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)))
mapOverArr (Pat (VarWisdom, TypeBase Shape NoUniqueness)
pat, ArrayOp
op)
      | Just (VName
_, VName
arr) <- ((VName, VName) -> Bool)
-> [(VName, VName)] -> Maybe (VName, VName)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== ArrayOp -> VName
arrayOpArr ArrayOp
op) (VName -> Bool)
-> ((VName, VName) -> VName) -> (VName, VName) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, VName) -> VName
forall a b. (a, b) -> a
fst) ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
map_param_names [VName]
arrs),
        VName
arr VName -> Names -> Bool
`notNameIn` Names
consumed = do
          TypeBase Shape NoUniqueness
arr_t <- VName -> RuleM (Wise SOACS) (TypeBase Shape NoUniqueness)
forall rep (m :: * -> *).
HasScope rep m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
arr
          let whole_dim :: DimIndex SubExp
whole_dim = SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (Int -> TypeBase Shape NoUniqueness -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 TypeBase Shape NoUniqueness
arr_t) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1)
          VName
arr_transformed <- Certs -> RuleM (Wise SOACS) VName -> RuleM (Wise SOACS) VName
forall a. Certs -> RuleM (Wise SOACS) a -> RuleM (Wise SOACS) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (ArrayOp -> Certs
arrayOpCerts ArrayOp
op) (RuleM (Wise SOACS) VName -> RuleM (Wise SOACS) VName)
-> RuleM (Wise SOACS) VName -> RuleM (Wise SOACS) VName
forall a b. (a -> b) -> a -> b
$
            Name -> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) VName
forall (m :: * -> *).
MonadBuilder m =>
Name -> Exp (Rep m) -> m VName
letExp (VName -> Name
baseName VName
arr Name -> Name -> Name
forall a. Semigroup a => a -> a -> a
<> Name
"_transformed") (Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) VName)
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) VName
forall a b. (a -> b) -> a -> b
$
              case ArrayOp
op of
                ArrayIndexing Certs
_ VName
_ (Slice [DimIndex SubExp]
slice) ->
                  BasicOp -> Exp (Rep (RuleM (Wise SOACS)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM (Wise SOACS))))
-> BasicOp -> Exp (Rep (RuleM (Wise SOACS)))
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ DimIndex SubExp
whole_dim DimIndex SubExp -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. a -> [a] -> [a]
: [DimIndex SubExp]
slice
                ArrayRearrange Certs
_ VName
_ [Int]
perm ->
                  BasicOp -> Exp (Rep (RuleM (Wise SOACS)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM (Wise SOACS))))
-> BasicOp -> Exp (Rep (RuleM (Wise SOACS)))
forall a b. (a -> b) -> a -> b
$ VName -> [Int] -> BasicOp
Rearrange VName
arr (Int
0 Int -> [Int] -> [Int]
forall a. a -> [a] -> [a]
: (Int -> Int) -> [Int] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) [Int]
perm)
                ArrayReshape Certs
_ VName
_ NewShape SubExp
new_shape ->
                  BasicOp -> Exp (Rep (RuleM (Wise SOACS)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM (Wise SOACS))))
-> BasicOp -> Exp (Rep (RuleM (Wise SOACS)))
forall a b. (a -> b) -> a -> b
$ VName -> NewShape SubExp -> BasicOp
Reshape VName
arr (NewShape SubExp -> BasicOp) -> NewShape SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ SubExp -> NewShape SubExp -> NewShape SubExp
reshapeInner SubExp
w NewShape SubExp
new_shape
                ArrayCopy {} ->
                  BasicOp -> Exp (Rep (RuleM (Wise SOACS)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM (Wise SOACS))))
-> BasicOp -> Exp (Rep (RuleM (Wise SOACS)))
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate Shape
forall a. Monoid a => a
mempty (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
arr
                ArrayVar {} ->
                  BasicOp -> Exp (Rep (RuleM (Wise SOACS)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM (Wise SOACS))))
-> BasicOp -> Exp (Rep (RuleM (Wise SOACS)))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
arr
          TypeBase Shape NoUniqueness
arr_transformed_t <- VName -> RuleM (Wise SOACS) (TypeBase Shape NoUniqueness)
forall rep (m :: * -> *).
HasScope rep m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
arr_transformed
          VName
arr_transformed_row <- Name -> RuleM (Wise SOACS) VName
forall (m :: * -> *). MonadFreshNames m => Name -> m VName
newVName (Name -> RuleM (Wise SOACS) VName)
-> Name -> RuleM (Wise SOACS) VName
forall a b. (a -> b) -> a -> b
$ VName -> Name
baseName VName
arr Name -> Name -> Name
forall a. Semigroup a => a -> a -> a
<> Name
"_transformed_row"
          Maybe
  (VName, Param (TypeBase Shape NoUniqueness),
   (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))
-> RuleM
     (Wise SOACS)
     (Maybe
        (VName, Param (TypeBase Shape NoUniqueness),
         (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)))
forall a. a -> RuleM (Wise SOACS) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe
   (VName, Param (TypeBase Shape NoUniqueness),
    (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))
 -> RuleM
      (Wise SOACS)
      (Maybe
         (VName, Param (TypeBase Shape NoUniqueness),
          (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))))
-> Maybe
     (VName, Param (TypeBase Shape NoUniqueness),
      (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))
-> RuleM
     (Wise SOACS)
     (Maybe
        (VName, Param (TypeBase Shape NoUniqueness),
         (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)))
forall a b. (a -> b) -> a -> b
$
            (VName, Param (TypeBase Shape NoUniqueness),
 (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))
-> Maybe
     (VName, Param (TypeBase Shape NoUniqueness),
      (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))
forall a. a -> Maybe a
Just
              ( VName
arr_transformed,
                Attrs
-> VName
-> TypeBase Shape NoUniqueness
-> Param (TypeBase Shape NoUniqueness)
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
forall a. Monoid a => a
mempty VName
arr_transformed_row (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
forall u. TypeBase Shape u -> TypeBase Shape u
rowType TypeBase Shape NoUniqueness
arr_transformed_t),
                (Pat (VarWisdom, TypeBase Shape NoUniqueness)
pat, Certs -> VName -> ArrayOp
ArrayVar Certs
forall a. Monoid a => a
mempty VName
arr_transformed_row)
              )
    mapOverArr (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)
_ = Maybe
  (VName, Param (TypeBase Shape NoUniqueness),
   (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))
-> RuleM
     (Wise SOACS)
     (Maybe
        (VName, Param (TypeBase Shape NoUniqueness),
         (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp)))
forall a. a -> RuleM (Wise SOACS) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe
  (VName, Param (TypeBase Shape NoUniqueness),
   (Pat (VarWisdom, TypeBase Shape NoUniqueness), ArrayOp))
forall a. Maybe a
Nothing
moveTransformToInput SymbolTable (Wise SOACS)
_ Pat (LetDec (Wise SOACS))
_ StmAux (ExpDec (Wise SOACS))
_ Op (Wise SOACS)
_ =
  Rule (Wise SOACS)
forall rep. Rule rep
Skip

-- The idea behind this rule is to take cases such as
--
--   let ...A... =
--     map (\x -> ...
--              let x = ...
--              ...
--              let y = f(x)
--              ...
--              in ...y ...)
--
-- where 'f' is some transformation like a reshape, and move it out
-- such that we get
--
--   let ...A'... =
--     map (\x -> ...
--              let x = ...
--              ...
--              in ...x ...)
--   let A = f'(A')
--
-- This can improve simplification in case A' fuses or simplifies with
-- something else.
--
-- TODO: currently we only handle reshapes here, but the principle
-- should actually hold for any ArrayTransform.
moveTransformToOutput :: TopDownRuleOp (Wise SOACS)
moveTransformToOutput vtable screma_pat screma_aux (Screma w arrs (ScremaForm map_lam scan reduce))
  -- Only fire when at least one reshape was actually pulled out of
  -- the lambda body; otherwise fall through to the Skip clause.
  | not (null hoisted) = Simplify $ do
      moved <- forM hoisted mkTransformed
      let (tr_res, tr_rets, tr_names, post) = unzip4 moved
          -- The rewritten lambda returns the untouched results first,
          -- followed by the pre-transformation arrays of the hoisted
          -- reshapes.
          map_lam' =
            map_lam
              { lambdaBody = mkBody kept_stms (nonmap_res <> map_res' <> tr_res),
                lambdaReturnType = nonmap_rets <> map_rets' <> tr_rets
              }
          pat_names =
            map patElemName nonmap_pes
              <> map patElemName map_pes'
              <> tr_names
      auxing screma_aux $
        letBindNames pat_names $
          Op (Screma w arrs (ScremaForm map_lam' scan reduce))
      -- Re-bind the original pattern names by applying the hoisted
      -- transformations to the fresh SOAC outputs.
      sequence_ post
  where
    lam_body = lambdaBody map_lam

    -- The leading results are counted by the scan and reduce
    -- operators; everything after them is treated as a map result.
    num_nonmap_res = scanResults scan + redResults reduce
    (nonmap_pes, map_pes) = splitAt num_nonmap_res (patElems screma_pat)
    (nonmap_rets, map_rets) = splitAt num_nonmap_res (lambdaReturnType map_lam)
    (nonmap_res, map_res) = splitAt num_nonmap_res (bodyResult lam_body)

    -- Names bound by the statements of the lambda body.
    scope = scopeOf (bodyStms lam_body)

    -- Walk the body once, collecting the hoistable reshapes, the map
    -- results that are left alone, and the statements that stay.
    (hoisted, remaining, kept_stms) =
      foldl' step ([], zip3 map_res map_rets map_pes, mempty) (bodyStms lam_body)
    (map_res', map_rets', map_pes') = unzip3 remaining

    -- True when every free name of the argument is present in the
    -- symbol table of the enclosing scope.
    invariantToMap x = all (`ST.elem` vtable) (namesToList (freeIn x))

    step (moved, pending, kept) stm =
      case stm of
        Let (Pat [pe]) aux (BasicOp (Reshape arr new_shape))
          -- Exactly one pending map result must be this reshape's
          -- output, the reshaped array must be bound inside the body,
          -- and its type and the new shape must pass invariantToMap.
          | ([(res, _, screma_pe)], others) <- partition (producedBy pe) pending,
            Just t <- fmap typeOf (M.lookup arr scope),
            invariantToMap (t, new_shape) ->
              let cs = stmAuxCerts aux <> resCerts res
                  transform = (arr, cs, \v -> BasicOp (Reshape v (reshapeInner w new_shape)))
               in ((t, screma_pe, transform) : moved, others, kept)
        _ ->
          (moved, pending, kept <> oneStm stm)
      where
        producedBy pe (r, _, _) = resSubExp r == Var (patElemName pe)

    -- For one hoisted transformation: invent a name for the
    -- untransformed SOAC output, and build the action that later binds
    -- the original pattern name to the transformation of that output.
    mkTransformed (t, pe, (arr, cs, f)) = do
      pretr <- newVName (baseName (patElemName pe) <> "_pretr")
      let bind = letBindNames [patElemName pe] (f pretr)
      pure (SubExpRes cs (Var arr), t, pretr, bind)
moveTransformToOutput _ _ _ _ =
  Skip