{-# LANGUAGE Strict #-}

-- | This module consists of rules for fusion
--     that involves WithAcc constructs.
--   Currently, we support two non-trivial
--   transformations:
--     I. map-flatten-scatter: a map nest produces
--        multi-dimensional index and values arrays
--        that are then flattened and used in a
--        scatter consumer. Such pattern can be fused
--        by re-writing the scatter by means of a WithAcc
--        containing a map-nest, thus eliminating the flatten
--        operations. The obtained WithAcc can then be fused
--        with the producer map nest, e.g., benefiting intra-group
--        kernels. The motivating target for this rule is
--        an efficient implementation of radix sort.
--
--    II. WithAcc-WithAcc fusion: two withaccs can be
--        fused as long as the common accumulators use
--        the same operator, and as long as the non-accumulator
--        input of a WithAcc is not used as an accumulator in
--        the other. This fusion opens the door for fusing
--        the SOACs appearing inside the WithAccs. This is
--        also intended to demonstrate that it is not so
--        important where exactly the WithAccs were originally
--        introduced in the code, it is more important that
--        they can be transformed by various optimization passes.
module Futhark.Optimise.Fusion.RulesWithAccs
  ( tryFuseWithAccs,
  )
where

import Control.Monad
import Data.Map.Strict qualified as M
import Futhark.Construct
import Futhark.IR.SOACS hiding (SOAC (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute

---------------------------------------------------
--- II. WithAcc-WithAcc Fusion
---------------------------------------------------

-- | Local helper type gathering everything that belongs to a single
--   accumulator of a @WithAcc@ construct:
--
--   1. the pattern elements bound to this accumulator's results;
--   2. the @WithAcc@ input itself (shape, arrays, optional operator);
--   3. the lambda parameter carrying the accumulator's certificate;
--   4. the lambda parameter carrying the accumulator argument;
--   5. the lambda result for this accumulator: the result name
--      together with the certificates attached to that result.
type AccTup =
  ( [PatElem (LetDec SOACS)], -- 1. result pattern elements
    WithAccInput SOACS, -- 2. withacc input
    LParam SOACS, -- 3. certificate parameter
    LParam SOACS, -- 4. accumulator parameter
    (VName, Certs) -- 5. result name and its certificates
  )

-- | Project component 1 of an 'AccTup': the pattern elements bound to
--   the accumulator's results.
accTup1 :: AccTup -> [PatElem (LetDec SOACS)]
accTup1 (pat_els, _, _, _, _) = pat_els

-- | Project component 2 of an 'AccTup': the @WithAcc@ input
--   (shape, arrays and optional operator) for this accumulator.
accTup2 :: AccTup -> WithAccInput SOACS
accTup2 (_, w_inp, _, _, _) = w_inp

-- | Project component 3 of an 'AccTup': the withacc lambda parameter
--   that carries this accumulator's certificate.
accTup3 :: AccTup -> LParam SOACS
accTup3 (_, _, cert_param, _, _) = cert_param

-- | Project component 4 of an 'AccTup': the withacc lambda parameter
--   that carries the accumulator argument itself.
accTup4 :: AccTup -> LParam SOACS
accTup4 (_, _, _, acc_param, _) = acc_param

-- | Project component 5 of an 'AccTup': the withacc lambda's result for
--   this accumulator, i.e. the result name paired with its certificates.
accTup5 :: AccTup -> (VName, Certs)
accTup5 (_, _, _, _, res) = res

-- | Simple case for fusing two withAccs (can be extended):
--    let (b1, ..., bm, x1, ..., xq) = withAcc a1 ... am lam1
--    let (d1, ..., dn, y1, ..., yp) = withAcc c1 ... cn lam2
-- Notation: `b1 ... bm` are the accumulator results of the
--     first withAcc and `d1, ..., dn` of the second withAcc.
--     `x1 ... xq` and `y1, ..., yp` are non-accumulator results.
-- Conservative conditions:
--   1. for any bi (i=1..m) either `bi IN {c1, ..., cn}` OR
--        `bi NOT-IN FV(lam2)`, i.e., perfect producer-consumer
--        relation on accums. Of course the binary-op should
--        be the same.
--   2. The `bs` that are also accumulated upon in lam2
--        do NOT belong to the `infusible` set (they are destroyed)
--   3. x1 ... xq do not overlap with c1 ... cn
-- Fusion will create one withacc that accumulates on the
--   union of `a1 ... am` and `c1 ... cn` and returns, in addition
--   to the accumulator arrays the union of regular variables
--   `x1 ... xq` and `y1, ..., yp`
tryFuseWithAccs ::
  (HasScope SOACS m, MonadFreshNames m) =>
  [VName] ->
  Stm SOACS ->
  Stm SOACS ->
  Maybe (m (Stm SOACS))
tryFuseWithAccs :: forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
[VName] -> Stm SOACS -> Stm SOACS -> Maybe (m (Stm SOACS))
tryFuseWithAccs
  [VName]
infusible
  (Let Pat (LetDec SOACS)
pat1 StmAux (ExpDec SOACS)
aux1 (WithAcc [WithAccInput SOACS]
w_inps1 Lambda SOACS
lam1))
  (Let Pat (LetDec SOACS)
pat2 StmAux (ExpDec SOACS)
aux2 (WithAcc [WithAccInput SOACS]
w_inps2 Lambda SOACS
lam2))
    | ([PatElem Type]
pat1_els, [PatElem Type]
pat2_els) <- (Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
Pat (LetDec SOACS)
pat1, Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
Pat (LetDec SOACS)
pat2),
      ([AccTup]
acc_tup1, [(PatElem (LetDec SOACS), SubExpRes)]
other_pr1) <- [PatElem (LetDec SOACS)]
-> [WithAccInput SOACS]
-> Lambda SOACS
-> ([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
groupAccs [PatElem Type]
[PatElem (LetDec SOACS)]
pat1_els [WithAccInput SOACS]
w_inps1 Lambda SOACS
lam1,
      ([AccTup]
acc_tup2, [(PatElem (LetDec SOACS), SubExpRes)]
other_pr2) <- [PatElem (LetDec SOACS)]
-> [WithAccInput SOACS]
-> Lambda SOACS
-> ([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
groupAccs [PatElem Type]
[PatElem (LetDec SOACS)]
pat2_els [WithAccInput SOACS]
w_inps2 Lambda SOACS
lam2,
      ([(AccTup, AccTup)]
tup_common, [AccTup]
acc_tup1', [AccTup]
acc_tup2') <-
        [AccTup] -> [AccTup] -> ([(AccTup, AccTup)], [AccTup], [AccTup])
groupCommonAccs [AccTup]
acc_tup1 [AccTup]
acc_tup2,
      -- safety 0: make sure that the accs from acc_tup1' and
      --           acc_tup2' do not overlap
      [VName]
pnms_1' <- (PatElem Type -> VName) -> [PatElem Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map PatElem Type -> VName
forall dec. PatElem dec -> VName
patElemName ([PatElem Type] -> [VName]) -> [PatElem Type] -> [VName]
forall a b. (a -> b) -> a -> b
$ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> [PatElem Type])
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [PatElem Type]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (\([PatElem Type]
nms, WithAccInput SOACS
_, Param Type
_, Param Type
_, (VName, Certs)
_) -> [PatElem Type]
nms) [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
acc_tup1',
      [VName]
winp_2' <- (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> [VName])
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (\([PatElem Type]
_, (ShapeBase SubExp
_, [VName]
nms, Maybe (Lambda SOACS, [SubExp])
_), Param Type
_, Param Type
_, (VName, Certs)
_) -> [VName]
nms) [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
acc_tup2',
      Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Names -> Names -> Bool
namesIntersect ([VName] -> Names
namesFromList [VName]
pnms_1') ([VName] -> Names
namesFromList [VName]
winp_2'),
      -- safety 1: we have already determined the commons;
      --           now we also need to check NOT-IN FV(lam2)
      Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Names -> Names -> Bool
namesIntersect ([VName] -> Names
namesFromList [VName]
pnms_1') (Lambda SOACS -> Names
forall a. FreeIn a => a -> Names
freeIn Lambda SOACS
lam2),
      -- safety 2:
      -- bs <- map patElemName $ concatMap accTup1 acc_tup1,
      [VName]
bs <- (PatElem Type -> VName) -> [PatElem Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map PatElem Type -> VName
forall dec. PatElem dec -> VName
patElemName ([PatElem Type] -> [VName]) -> [PatElem Type] -> [VName]
forall a b. (a -> b) -> a -> b
$ ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))
 -> [PatElem Type])
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))]
-> [PatElem Type]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> [PatElem Type]
AccTup -> [PatElem (LetDec SOACS)]
accTup1 (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> [PatElem Type])
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))
    -> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
        (VName, Certs)))
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs)),
    ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs)))
-> [PatElem Type]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)),
 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
    (VName, Certs))
forall a b. (a, b) -> a
fst) [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))]
[(AccTup, AccTup)]
tup_common,
      (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` [VName]
infusible) [VName]
bs,
      -- safety 3:
      Names
cs <- [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> [VName])
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap ((\(ShapeBase SubExp
_, [VName]
xs, Maybe (Lambda SOACS, [SubExp])
_) -> [VName]
xs) (WithAccInput SOACS -> [VName])
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))
    -> WithAccInput SOACS)
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
    (VName, Certs))
-> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> WithAccInput SOACS
AccTup -> WithAccInput SOACS
accTup2) [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
acc_tup2,
      ((PatElem Type, SubExpRes) -> Bool)
-> [(PatElem Type, SubExpRes)] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all ((VName -> Names -> Bool
`notNameIn` Names
cs) (VName -> Bool)
-> ((PatElem Type, SubExpRes) -> VName)
-> (PatElem Type, SubExpRes)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElem Type -> VName
forall dec. PatElem dec -> VName
patElemName (PatElem Type -> VName)
-> ((PatElem Type, SubExpRes) -> PatElem Type)
-> (PatElem Type, SubExpRes)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatElem Type, SubExpRes) -> PatElem Type
forall a b. (a, b) -> a
fst) [(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
other_pr1 = m (Stm SOACS) -> Maybe (m (Stm SOACS))
forall a. a -> Maybe a
Just (m (Stm SOACS) -> Maybe (m (Stm SOACS)))
-> m (Stm SOACS) -> Maybe (m (Stm SOACS))
forall a b. (a -> b) -> a -> b
$ do
        let getCertPairs :: (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)),
 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)))
-> (VName, VName)
getCertPairs (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
t1, ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
t2) = (Param Type -> VName
forall dec. Param dec -> VName
paramName (AccTup -> LParam SOACS
accTup3 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
AccTup
t2), Param Type -> VName
forall dec. Param dec -> VName
paramName (AccTup -> LParam SOACS
accTup3 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
AccTup
t1))
            tab_certs :: Map VName VName
tab_certs = [(VName, VName)] -> Map VName VName
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VName, VName)] -> Map VName VName)
-> [(VName, VName)] -> Map VName VName
forall a b. (a -> b) -> a -> b
$ ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))
 -> (VName, VName))
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))]
-> [(VName, VName)]
forall a b. (a -> b) -> [a] -> [b]
map (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)),
 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)))
-> (VName, VName)
getCertPairs [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))]
[(AccTup, AccTup)]
tup_common
            lam2_bdy' :: Body SOACS
lam2_bdy' = Map VName VName -> Body SOACS -> Body SOACS
forall a. Substitute a => Map VName VName -> a -> a
substituteNames Map VName VName
tab_certs (Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam2)
            rcrt_params :: [Param Type]
rcrt_params = ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))
 -> Param Type)
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))]
-> [Param Type]
forall a b. (a -> b) -> [a] -> [b]
map (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> Param Type
AccTup -> LParam SOACS
accTup3 (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> Param Type)
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))
    -> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
        (VName, Certs)))
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs)),
    ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs)))
-> Param Type
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)),
 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
    (VName, Certs))
forall a b. (a, b) -> a
fst) [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))]
[(AccTup, AccTup)]
tup_common [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> Param Type)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [Param Type]
forall a b. (a -> b) -> [a] -> [b]
map ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> Param Type
AccTup -> LParam SOACS
accTup3 [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
acc_tup1' [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> Param Type)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [Param Type]
forall a b. (a -> b) -> [a] -> [b]
map ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> Param Type
AccTup -> LParam SOACS
accTup3 [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
acc_tup2'
            racc_params :: [Param Type]
racc_params = ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))
 -> Param Type)
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))]
-> [Param Type]
forall a b. (a -> b) -> [a] -> [b]
map (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> Param Type
AccTup -> LParam SOACS
accTup4 (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> Param Type)
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))
    -> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
        (VName, Certs)))
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs)),
    ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs)))
-> Param Type
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)),
 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
    (VName, Certs))
forall a b. (a, b) -> a
fst) [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))]
[(AccTup, AccTup)]
tup_common [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> Param Type)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [Param Type]
forall a b. (a -> b) -> [a] -> [b]
map ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> Param Type
AccTup -> LParam SOACS
accTup4 [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
acc_tup1' [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> Param Type)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [Param Type]
forall a b. (a -> b) -> [a] -> [b]
map ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> Param Type
AccTup -> LParam SOACS
accTup4 [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
acc_tup2'
            ([VName]
comm_res_nms, [Certs]
comm_res_certs2) = [(VName, Certs)] -> ([VName], [Certs])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, Certs)] -> ([VName], [Certs]))
-> [(VName, Certs)] -> ([VName], [Certs])
forall a b. (a -> b) -> a -> b
$ ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))
 -> (VName, Certs))
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))]
-> [(VName, Certs)]
forall a b. (a -> b) -> [a] -> [b]
map (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> (VName, Certs)
AccTup -> (VName, Certs)
accTup5 (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> (VName, Certs))
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))
    -> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
        (VName, Certs)))
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs)),
    ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs)))
-> (VName, Certs)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)),
 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
    (VName, Certs))
forall a b. (a, b) -> b
snd) [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))]
[(AccTup, AccTup)]
tup_common
            ([VName]
_, [Certs]
comm_res_certs1) = [(VName, Certs)] -> ([VName], [Certs])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, Certs)] -> ([VName], [Certs]))
-> [(VName, Certs)] -> ([VName], [Certs])
forall a b. (a -> b) -> a -> b
$ ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))
 -> (VName, Certs))
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))]
-> [(VName, Certs)]
forall a b. (a -> b) -> [a] -> [b]
map (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> (VName, Certs)
AccTup -> (VName, Certs)
accTup5 (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> (VName, Certs))
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))
    -> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
        (VName, Certs)))
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs)),
    ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs)))
-> (VName, Certs)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)),
 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
    (VName, Certs))
forall a b. (a, b) -> a
fst) [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))]
[(AccTup, AccTup)]
tup_common
            com_res_certs :: [Certs]
com_res_certs = (Certs -> Certs -> Certs) -> [Certs] -> [Certs] -> [Certs]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\Certs
x Certs
y -> [VName] -> Certs
Certs (Certs -> [VName]
unCerts Certs
x [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ Certs -> [VName]
unCerts Certs
y)) [Certs]
comm_res_certs1 [Certs]
comm_res_certs2
            bdyres_certs :: [Certs]
bdyres_certs = [Certs]
com_res_certs [Certs] -> [Certs] -> [Certs]
forall a. [a] -> [a] -> [a]
++ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> Certs)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [Certs]
forall a b. (a -> b) -> [a] -> [b]
map ((VName, Certs) -> Certs
forall a b. (a, b) -> b
snd ((VName, Certs) -> Certs)
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))
    -> (VName, Certs))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
    (VName, Certs))
-> Certs
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> (VName, Certs)
AccTup -> (VName, Certs)
accTup5) ([([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
acc_tup1' [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
forall a. [a] -> [a] -> [a]
++ [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
acc_tup2')
            bdyres_accse :: [SubExp]
bdyres_accse = (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
comm_res_nms [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> SubExp)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))
    -> VName)
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
    (VName, Certs))
-> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, Certs) -> VName
forall a b. (a, b) -> a
fst ((VName, Certs) -> VName)
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))
    -> (VName, Certs))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
    (VName, Certs))
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> (VName, Certs)
AccTup -> (VName, Certs)
accTup5) ([([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
acc_tup1' [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
forall a. [a] -> [a] -> [a]
++ [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
acc_tup2')
            bdy_res_accs :: [SubExpRes]
bdy_res_accs = (Certs -> SubExp -> SubExpRes)
-> [Certs] -> [SubExp] -> [SubExpRes]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Certs -> SubExp -> SubExpRes
SubExpRes [Certs]
bdyres_certs [SubExp]
bdyres_accse
            bdy_res_others :: [SubExpRes]
bdy_res_others = ((PatElem Type, SubExpRes) -> SubExpRes)
-> [(PatElem Type, SubExpRes)] -> [SubExpRes]
forall a b. (a -> b) -> [a] -> [b]
map (PatElem Type, SubExpRes) -> SubExpRes
forall a b. (a, b) -> b
snd ([(PatElem Type, SubExpRes)] -> [SubExpRes])
-> [(PatElem Type, SubExpRes)] -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ [(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
other_pr1 [(PatElem Type, SubExpRes)]
-> [(PatElem Type, SubExpRes)] -> [(PatElem Type, SubExpRes)]
forall a. [a] -> [a] -> [a]
++ [(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
other_pr2
        Scope SOACS
scope <- m (Scope SOACS)
forall rep (m :: * -> *). HasScope rep m => m (Scope rep)
askScope
        Body SOACS
lam_bdy <-
          Builder SOACS [SubExpRes] -> m (Body SOACS)
forall rep (m :: * -> *) somerep res.
(Buildable rep, MonadFreshNames m, HasScope somerep m,
 SameScope somerep rep, IsResult res) =>
Builder rep [res] -> m (GBody rep res)
runBodyBuilder (Builder SOACS [SubExpRes] -> m (Body SOACS))
-> Builder SOACS [SubExpRes] -> m (Body SOACS)
forall a b. (a -> b) -> a -> b
$ do
            Scope SOACS
-> Builder SOACS [SubExpRes] -> Builder SOACS [SubExpRes]
forall a.
Scope SOACS
-> BuilderT SOACS (State VNameSource) a
-> BuilderT SOACS (State VNameSource) a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (Scope SOACS
scope Scope SOACS -> Scope SOACS -> Scope SOACS
forall a. Semigroup a => a -> a -> a
<> [Param Type] -> Scope SOACS
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams ([Param Type]
rcrt_params [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
racc_params)) (Builder SOACS [SubExpRes] -> Builder SOACS [SubExpRes])
-> Builder SOACS [SubExpRes] -> Builder SOACS [SubExpRes]
forall a b. (a -> b) -> a -> b
$ do
              -- add the stms of lam1
              (Stm SOACS -> BuilderT SOACS (State VNameSource) ())
-> [Stm SOACS] -> BuilderT SOACS (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
Stm SOACS -> BuilderT SOACS (State VNameSource) ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm ([Stm SOACS] -> BuilderT SOACS (State VNameSource) ())
-> [Stm SOACS] -> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Stms SOACS -> [Stm SOACS]
forall rep. Stms rep -> [Stm rep]
stmsToList (Stms SOACS -> [Stm SOACS]) -> Stms SOACS -> [Stm SOACS]
forall a b. (a -> b) -> a -> b
$ Body SOACS -> Stms SOACS
forall rep res. GBody rep res -> Stms rep
bodyStms (Body SOACS -> Stms SOACS) -> Body SOACS -> Stms SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam1
              -- add the copy stms for the common accumulator
              [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))]
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))
    -> BuilderT SOACS (State VNameSource) ())
-> BuilderT SOACS (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))]
[(AccTup, AccTup)]
tup_common (((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
    (VName, Certs)),
   ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
    (VName, Certs)))
  -> BuilderT SOACS (State VNameSource) ())
 -> BuilderT SOACS (State VNameSource) ())
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))
    -> BuilderT SOACS (State VNameSource) ())
-> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
tup1, ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
tup2) -> do
                let (Param Type
lpar1, Param Type
lpar2) = (AccTup -> LParam SOACS
accTup4 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
AccTup
tup1, AccTup -> LParam SOACS
accTup4 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
AccTup
tup2)
                    ((VName
nm1, Certs
_), VName
nm2, Type
tp_acc) = (AccTup -> (VName, Certs)
accTup5 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
AccTup
tup1, Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
lpar2, Param Type -> Type
forall dec. Param dec -> dec
paramDec Param Type
lpar1)
                Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind ([PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat [VName -> Type -> PatElem Type
forall dec. VName -> dec -> PatElem dec
PatElem VName
nm2 Type
tp_acc]) (Exp (Rep (BuilderT SOACS (State VNameSource)))
 -> BuilderT SOACS (State VNameSource) ())
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT SOACS (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT SOACS (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT SOACS (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
nm1
              -- add copy stms to bring in scope x1 ... xq
              [(PatElem Type, SubExpRes)]
-> ((PatElem Type, SubExpRes)
    -> BuilderT SOACS (State VNameSource) ())
-> BuilderT SOACS (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
other_pr1 (((PatElem Type, SubExpRes)
  -> BuilderT SOACS (State VNameSource) ())
 -> BuilderT SOACS (State VNameSource) ())
-> ((PatElem Type, SubExpRes)
    -> BuilderT SOACS (State VNameSource) ())
-> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(PatElem Type
pat_elm, SubExpRes
bdy_res) -> do
                let (VName
nm, SubExp
se, Type
tp) = (PatElem Type -> VName
forall dec. PatElem dec -> VName
patElemName PatElem Type
pat_elm, SubExpRes -> SubExp
resSubExp SubExpRes
bdy_res, PatElem Type -> Type
forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem Type
pat_elm)
                Certs
-> BuilderT SOACS (State VNameSource) ()
-> BuilderT SOACS (State VNameSource) ()
forall a.
Certs
-> BuilderT SOACS (State VNameSource) a
-> BuilderT SOACS (State VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (SubExpRes -> Certs
resCerts SubExpRes
bdy_res) (BuilderT SOACS (State VNameSource) ()
 -> BuilderT SOACS (State VNameSource) ())
-> BuilderT SOACS (State VNameSource) ()
-> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
                  Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind ([PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat [VName -> Type -> PatElem Type
forall dec. VName -> dec -> PatElem dec
PatElem VName
nm Type
tp]) (Exp (Rep (BuilderT SOACS (State VNameSource)))
 -> BuilderT SOACS (State VNameSource) ())
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
                    BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (SubExp -> BasicOp
SubExp SubExp
se)
              -- add the statements of lam2 (in which the acc-certificates have been substituted)
              (Stm SOACS -> BuilderT SOACS (State VNameSource) ())
-> [Stm SOACS] -> BuilderT SOACS (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
Stm SOACS -> BuilderT SOACS (State VNameSource) ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm ([Stm SOACS] -> BuilderT SOACS (State VNameSource) ())
-> [Stm SOACS] -> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Stms SOACS -> [Stm SOACS]
forall rep. Stms rep -> [Stm rep]
stmsToList (Stms SOACS -> [Stm SOACS]) -> Stms SOACS -> [Stm SOACS]
forall a b. (a -> b) -> a -> b
$ Body SOACS -> Stms SOACS
forall rep res. GBody rep res -> Stms rep
bodyStms Body SOACS
lam2_bdy'
              -- build the result of body
              [SubExpRes] -> Builder SOACS [SubExpRes]
forall a. a -> BuilderT SOACS (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([SubExpRes] -> Builder SOACS [SubExpRes])
-> [SubExpRes] -> Builder SOACS [SubExpRes]
forall a b. (a -> b) -> a -> b
$ [SubExpRes]
bdy_res_accs [SubExpRes] -> [SubExpRes] -> [SubExpRes]
forall a. [a] -> [a] -> [a]
++ [SubExpRes]
bdy_res_others
        let tp_res_other :: [Type]
tp_res_other = ((PatElem Type, SubExpRes) -> Type)
-> [(PatElem Type, SubExpRes)] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (PatElem Type -> Type
forall dec. Typed dec => PatElem dec -> Type
patElemType (PatElem Type -> Type)
-> ((PatElem Type, SubExpRes) -> PatElem Type)
-> (PatElem Type, SubExpRes)
-> Type
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatElem Type, SubExpRes) -> PatElem Type
forall a b. (a, b) -> a
fst) ([(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
other_pr1 [(PatElem Type, SubExpRes)]
-> [(PatElem Type, SubExpRes)] -> [(PatElem Type, SubExpRes)]
forall a. [a] -> [a] -> [a]
++ [(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
other_pr2)
            res_lam :: Lambda SOACS
res_lam =
              Lambda
                { lambdaParams :: [LParam SOACS]
lambdaParams = [Param Type]
rcrt_params [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
racc_params,
                  lambdaBody :: Body SOACS
lambdaBody = Body SOACS
lam_bdy,
                  lambdaReturnType :: [Type]
lambdaReturnType = (Param Type -> Type) -> [Param Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Type
forall dec. Param dec -> dec
paramDec [Param Type]
racc_params [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
tp_res_other
                }
        Lambda SOACS
res_lam' <- Lambda SOACS -> m (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
res_lam
        let res_pat :: [PatElem Type]
res_pat =
              ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))
 -> [PatElem Type])
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))]
-> [PatElem Type]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> [PatElem Type]
AccTup -> [PatElem (LetDec SOACS)]
accTup1 (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> [PatElem Type])
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))
    -> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
        (VName, Certs)))
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs)),
    ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs)))
-> [PatElem Type]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)),
 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
    (VName, Certs))
forall a b. (a, b) -> b
snd) [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))]
[(AccTup, AccTup)]
tup_common
                [PatElem Type] -> [PatElem Type] -> [PatElem Type]
forall a. [a] -> [a] -> [a]
++ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> [PatElem Type])
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [PatElem Type]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> [PatElem Type]
AccTup -> [PatElem (LetDec SOACS)]
accTup1 ([([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
acc_tup1' [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
forall a. [a] -> [a] -> [a]
++ [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
acc_tup2')
                [PatElem Type] -> [PatElem Type] -> [PatElem Type]
forall a. [a] -> [a] -> [a]
++ ((PatElem Type, SubExpRes) -> PatElem Type)
-> [(PatElem Type, SubExpRes)] -> [PatElem Type]
forall a b. (a -> b) -> [a] -> [b]
map (PatElem Type, SubExpRes) -> PatElem Type
forall a b. (a, b) -> a
fst ([(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
other_pr1 [(PatElem Type, SubExpRes)]
-> [(PatElem Type, SubExpRes)] -> [(PatElem Type, SubExpRes)]
forall a. [a] -> [a] -> [a]
++ [(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
other_pr2)
            res_w_inps :: [WithAccInput SOACS]
res_w_inps = ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))
 -> WithAccInput SOACS)
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))]
-> [WithAccInput SOACS]
forall a b. (a -> b) -> [a] -> [b]
map (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> WithAccInput SOACS
AccTup -> WithAccInput SOACS
accTup2 (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> WithAccInput SOACS)
-> ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))
    -> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
        (VName, Certs)))
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs)),
    ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs)))
-> WithAccInput SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)),
 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)))
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
    (VName, Certs))
forall a b. (a, b) -> a
fst) [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))]
[(AccTup, AccTup)]
tup_common [WithAccInput SOACS]
-> [WithAccInput SOACS] -> [WithAccInput SOACS]
forall a. [a] -> [a] -> [a]
++ (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> WithAccInput SOACS)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [WithAccInput SOACS]
forall a b. (a -> b) -> [a] -> [b]
map ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> WithAccInput SOACS
AccTup -> WithAccInput SOACS
accTup2 ([([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
acc_tup1' [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
forall a. [a] -> [a] -> [a]
++ [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
acc_tup2')
        [WithAccInput SOACS]
res_w_inps' <- (WithAccInput SOACS -> m (WithAccInput SOACS))
-> [WithAccInput SOACS] -> m [WithAccInput SOACS]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM WithAccInput SOACS -> m (WithAccInput SOACS)
forall {m :: * -> *} {rep} {a} {b} {b}.
(Rename (OpC rep rep), Rename (LetDec rep), Rename (ExpDec rep),
 Rename (BodyDec rep), Rename (FParamInfo rep),
 Rename (LParamInfo rep), Rename (RetType rep),
 Rename (BranchType rep), MonadFreshNames m) =>
(a, b, Maybe (Lambda rep, b)) -> m (a, b, Maybe (Lambda rep, b))
renameLamInWAccInp [WithAccInput SOACS]
res_w_inps
        Stm SOACS -> m (Stm SOACS)
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stm SOACS -> m (Stm SOACS)) -> Stm SOACS -> m (Stm SOACS)
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem Type]
res_pat) (StmAux ()
StmAux (ExpDec SOACS)
aux1 StmAux () -> StmAux () -> StmAux ()
forall a. Semigroup a => a -> a -> a
<> StmAux ()
StmAux (ExpDec SOACS)
aux2) (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ [WithAccInput SOACS] -> Lambda SOACS -> Exp SOACS
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [WithAccInput SOACS]
res_w_inps' Lambda SOACS
res_lam'
    where
      -- local helpers:

      groupAccs ::
        [PatElem (LetDec SOACS)] ->
        [WithAccInput SOACS] ->
        Lambda SOACS ->
        ([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
      groupAccs :: [PatElem (LetDec SOACS)]
-> [WithAccInput SOACS]
-> Lambda SOACS
-> ([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
groupAccs [PatElem (LetDec SOACS)]
pat_els [WithAccInput SOACS]
wacc_inps Lambda SOACS
wlam =
        let lam_params :: [LParam SOACS]
lam_params = Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
wlam
            n :: Int
n = [Param Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Param Type]
lam_params
            ([Param Type]
lam_par_crts, [Param Type]
lam_par_accs) = Int -> [Param Type] -> ([Param Type], [Param Type])
forall a. Int -> [a] -> ([a], [a])
splitAt (Int
n Int -> Int -> Int
forall a. Integral a => a -> a -> a
`div` Int
2) [Param Type]
lam_params
            lab_res_ses :: [SubExpRes]
lab_res_ses = Body SOACS -> [SubExpRes]
forall rep res. GBody rep res -> [res]
bodyResult (Body SOACS -> [SubExpRes]) -> Body SOACS -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
wlam
         in [PatElem (LetDec SOACS)]
-> [WithAccInput SOACS]
-> [LParam SOACS]
-> [LParam SOACS]
-> [SubExpRes]
-> ([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
groupAccsHlp [PatElem (LetDec SOACS)]
pat_els [WithAccInput SOACS]
wacc_inps [Param Type]
[LParam SOACS]
lam_par_crts [Param Type]
[LParam SOACS]
lam_par_accs [SubExpRes]
lab_res_ses
      groupAccsHlp ::
        [PatElem (LetDec SOACS)] ->
        [WithAccInput SOACS] ->
        [LParam SOACS] ->
        [LParam SOACS] ->
        [SubExpRes] ->
        ([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
      groupAccsHlp :: [PatElem (LetDec SOACS)]
-> [WithAccInput SOACS]
-> [LParam SOACS]
-> [LParam SOACS]
-> [SubExpRes]
-> ([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
groupAccsHlp [PatElem (LetDec SOACS)]
pat_els [] [] [] [SubExpRes]
lam_res_ses
        | [PatElem Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem Type]
[PatElem (LetDec SOACS)]
pat_els Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExpRes] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExpRes]
lam_res_ses =
            ([], [PatElem Type] -> [SubExpRes] -> [(PatElem Type, SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem Type]
[PatElem (LetDec SOACS)]
pat_els [SubExpRes]
lam_res_ses)
      groupAccsHlp
        [PatElem (LetDec SOACS)]
pat_els
        (winp :: WithAccInput SOACS
winp@(ShapeBase SubExp
_, [VName]
inp, Maybe (Lambda SOACS, [SubExp])
_) : [WithAccInput SOACS]
wacc_inps)
        (LParam SOACS
par_crt : [LParam SOACS]
lam_par_crts)
        (LParam SOACS
par_acc : [LParam SOACS]
lam_par_accs)
        (SubExpRes
res_se : [SubExpRes]
lam_res_ses)
          | Int
n <- [VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
inp,
            (Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= [PatElem Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem Type]
[PatElem (LetDec SOACS)]
pat_els) Bool -> Bool -> Bool
&& (Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= (Int
1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [SubExpRes] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExpRes]
lam_res_ses)),
            Var VName
res_nm <- SubExpRes -> SubExp
resSubExp SubExpRes
res_se =
              let ([PatElem Type]
pat_els_cur, [PatElem Type]
pat_els') = Int -> [PatElem Type] -> ([PatElem Type], [PatElem Type])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
n [PatElem Type]
[PatElem (LetDec SOACS)]
pat_els
                  ([AccTup]
rec1, [(PatElem (LetDec SOACS), SubExpRes)]
rec2) = [PatElem (LetDec SOACS)]
-> [WithAccInput SOACS]
-> [LParam SOACS]
-> [LParam SOACS]
-> [SubExpRes]
-> ([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
groupAccsHlp [PatElem Type]
[PatElem (LetDec SOACS)]
pat_els' [WithAccInput SOACS]
wacc_inps [LParam SOACS]
lam_par_crts [LParam SOACS]
lam_par_accs [SubExpRes]
lam_res_ses
               in (([PatElem Type]
pat_els_cur, WithAccInput SOACS
winp, Param Type
LParam SOACS
par_crt, Param Type
LParam SOACS
par_acc, (VName
res_nm, SubExpRes -> Certs
resCerts SubExpRes
res_se)) ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
forall a. a -> [a] -> [a]
: [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
rec1, [(PatElem Type, SubExpRes)]
[(PatElem (LetDec SOACS), SubExpRes)]
rec2)
      groupAccsHlp [PatElem (LetDec SOACS)]
_ [WithAccInput SOACS]
_ [LParam SOACS]
_ [LParam SOACS]
_ [SubExpRes]
_ =
        [Char]
-> ([([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs))],
    [(PatElem Type, SubExpRes)])
forall a. HasCallStack => [Char] -> a
error [Char]
"Unreachable case reached in groupAccsHlp!"
      --
      groupCommonAccs :: [AccTup] -> [AccTup] -> ([(AccTup, AccTup)], [AccTup], [AccTup])
      groupCommonAccs :: [AccTup] -> [AccTup] -> ([(AccTup, AccTup)], [AccTup], [AccTup])
groupCommonAccs [] [AccTup]
tup_accs2 =
        ([], [], [AccTup]
tup_accs2)
      groupCommonAccs (AccTup
tup_acc1 : [AccTup]
tup_accs1) [AccTup]
tup_accs2
        | [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
commons2 <- (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> Bool)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
forall a. (a -> Bool) -> [a] -> [a]
filter (AccTup -> AccTup -> Bool
matchingAccTup AccTup
tup_acc1) [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
tup_accs2,
          [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
-> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
commons2 Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= Int
1 =
            let ([(AccTup, AccTup)]
rec1, [AccTup]
rec2, [AccTup]
rec3) =
                  [AccTup] -> [AccTup] -> ([(AccTup, AccTup)], [AccTup], [AccTup])
groupCommonAccs [AccTup]
tup_accs1 ([AccTup] -> ([(AccTup, AccTup)], [AccTup], [AccTup]))
-> [AccTup] -> ([(AccTup, AccTup)], [AccTup], [AccTup])
forall a b. (a -> b) -> a -> b
$
                    if [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
-> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
commons2
                      then [AccTup]
tup_accs2
                      else (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))
 -> Bool)
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
forall a. (a -> Bool) -> [a] -> [a]
filter (Bool -> Bool
not (Bool -> Bool)
-> (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))
    -> Bool)
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
    (VName, Certs))
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. AccTup -> AccTup -> Bool
matchingAccTup AccTup
tup_acc1) [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
tup_accs2
             in if [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
-> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
commons2
                  then ([(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))]
[(AccTup, AccTup)]
rec1, ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
AccTup
tup_acc1 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
-> [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
     (VName, Certs))]
forall a. a -> [a] -> [a]
: [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
rec2, [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
rec3)
                  else ((([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
 (VName, Certs))
AccTup
tup_acc1, [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
-> ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
    (VName, Certs))
forall a. HasCallStack => [a] -> a
head [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
commons2) (([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)),
 ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs)))
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))]
-> [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)),
     ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs)))]
forall a. a -> [a] -> [a]
: [(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)),
  ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
   (VName, Certs)))]
rec1, [AccTup]
tup_accs1, [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
  (VName, Certs))]
[AccTup]
rec3)
      groupCommonAccs [AccTup]
_ [AccTup]
_ =
        [Char]
-> ([(([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
       (VName, Certs)),
      ([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
       (VName, Certs)))],
    [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs))],
    [([PatElem Type], WithAccInput SOACS, Param Type, Param Type,
      (VName, Certs))])
forall a. HasCallStack => [Char] -> a
error [Char]
"Unreachable case reached in groupCommonAccs!"
      renameLamInWAccInp :: (a, b, Maybe (Lambda rep, b)) -> m (a, b, Maybe (Lambda rep, b))
renameLamInWAccInp (a
shp, b
inps, Just (Lambda rep
lam, b
se)) = do
        Lambda rep
lam' <- Lambda rep -> m (Lambda rep)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda rep
lam
        (a, b, Maybe (Lambda rep, b)) -> m (a, b, Maybe (Lambda rep, b))
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (a
shp, b
inps, (Lambda rep, b) -> Maybe (Lambda rep, b)
forall a. a -> Maybe a
Just (Lambda rep
lam', b
se))
      renameLamInWAccInp (a, b, Maybe (Lambda rep, b))
winp = (a, b, Maybe (Lambda rep, b)) -> m (a, b, Maybe (Lambda rep, b))
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (a, b, Maybe (Lambda rep, b))
winp
--
tryFuseWithAccs [VName]
_ Stm SOACS
_ Stm SOACS
_ =
  Maybe (m (Stm SOACS))
forall a. Maybe a
Nothing

-------------------------------
--- simple helper functions ---
-------------------------------

-- | Checks whether two lambdas are structurally equivalent modulo a
--   consistent renaming of variables.
--
--   The symbol table @stab@ maps names occurring in the second lambda to
--   the corresponding names of the first one. The two lambdas are deemed
--   equivalent when:
--
--     * their parameters have identical types and attributes,
--
--     * their return types and body decorations coincide,
--
--     * their bodies contain the same number of statements, and those
--       statements are pairwise equivalent according to 'equivStm'
--       (which threads and extends the renaming), and
--
--     * the result sub-expressions of the second body, after renaming,
--       are equal to those of the first body.
--
--   Result certificates are not compared; only 'resSubExp' is.
equivLambda ::
  M.Map VName VName ->
  Lambda SOACS ->
  Lambda SOACS ->
  Bool
equivLambda stab lam1 lam2
  | (ps1, ps2) <- (lambdaParams lam1, lambdaParams lam2),
    -- Comparing the per-parameter decorations also forces both
    -- parameter lists to have the same length.
    map paramDec ps1 == map paramDec ps2,
    map paramAttrs ps1 == map paramAttrs ps2,
    lambdaReturnType lam1 == lambdaReturnType lam2,
    (bdy1, bdy2) <- (lambdaBody lam1, lambdaBody lam2),
    bodyDec bdy1 == bodyDec bdy2,
    stms1 <- stmsToList (bodyStms bdy1),
    stms2 <- stmsToList (bodyStms bdy2),
    -- 'zip' below would silently drop the trailing statements of the
    -- longer body, so the statement counts must agree up front.
    length stms1 == length stms2 =
      let -- Map every parameter name of lam2 to its counterpart in lam1.
          -- 'M.union' is left-biased, so these bindings take precedence
          -- over any existing entry in stab.
          stab' =
            M.union
              (M.fromList $ zip (map paramName ps2) (map paramName ps1))
              stab
          -- Compare paired statements in order, threading the renaming
          -- produced by 'equivStm'; stop at the first mismatch.
          equivStms ::
            M.Map VName VName ->
            [(Stm SOACS, Stm SOACS)] ->
            Maybe (M.Map VName VName)
          equivStms vtab [] = Just vtab
          equivStms vtab ((s1, s2) : rest) =
            case equivStm vtab s1 s2 of
              (vtab', True) -> equivStms vtab' rest
              _ -> Nothing
       in case equivStms stab' (zip stms1 stms2) of
            Nothing -> False
            Just stab'' ->
              map resSubExp (bodyResult bdy1)
                == substInSEs stab'' (map resSubExp (bodyResult bdy2))
equivLambda _ _ _ =
  False

-- | Checks whether two statements are equivalent modulo the renaming
--   @vtab@, which maps names bound in the second body to the names
--   bound at the corresponding positions in the first body.
--
--   On success, the result pairs 'True' with an extended renaming. The
--   extension maps the pattern names of the second statement to those
--   of the first; new entries take precedence via a left-biased union.
--   On failure, the renaming is returned unchanged together with 'False'.
--
--   Only scalar binary operations are recognised so far. Both statements
--   must meet all of the following:
--
--   * apply the same 'BinOp';
--   * use operands that coincide once the second pair is renamed;
--   * bind pattern elements of identical types;
--   * carry identical 'StmAux' (compared as-is, without renaming).
equivStm ::
  M.Map VName VName ->
  Stm SOACS ->
  Stm SOACS ->
  (M.Map VName VName, Bool)
equivStm vtab stm1 stm2 =
  case (stm1, stm2) of
    ( Let pat1 aux1 (BasicOp (BinOp op1 x1 y1)),
      Let pat2 aux2 (BasicOp (BinOp op2 x2 y2))
      )
        | [x1, y1] == substInSEs vtab [x2, y2],
          map patElemDec (patElems pat1) == map patElemDec (patElems pat2),
          op1 == op2,
          aux1 == aux2 ->
            let renaming =
                  M.fromList $
                    zip
                      (map patElemName (patElems pat2))
                      (map patElemName (patElems pat1))
             in (renaming `M.union` vtab, True)
    -- Other statement kinds are not supported yet (to be continued),
    -- so they are conservatively reported as non-equivalent.
    _ -> (vtab, False)

-- | Decides whether an accumulator of the first WithAcc matches an
--   accumulator of the second WithAcc. All of the following must hold:
--
--   * both accumulators have the same shape dimensions;
--   * the input arrays of the second accumulator are exactly the names
--     bound by the first tuple's pattern elements, in the same order;
--   * either neither accumulator has an operator, or both have one. In
--     the latter case the neutral-element lists must be equal and the
--     operator lambdas equivalent under 'equivLambda' with an empty
--     initial renaming.
--
--   The lambda parameters and certificate components of the tuples are
--   not inspected.
matchingAccTup :: AccTup -> AccTup -> Bool
matchingAccTup (pes1, (shp1, _, op1), _, _, _) (_, (shp2, arrs2, op2), _, _, _) =
  and
    [ shapeDims shp1 == shapeDims shp2,
      map patElemName pes1 == arrs2,
      sameOp op1 op2
    ]
  where
    -- Two optional accumulator operators agree when both are absent, or
    -- both are present with equal neutral elements and equivalent lambdas.
    sameOp Nothing Nothing = True
    sameOp (Just (lam1, nes1)) (Just (lam2, nes2)) =
      nes1 == nes2 && equivLambda M.empty lam1 lam2
    sameOp _ _ = False

-- | Applies the renaming @vtab@ to a list of subexpressions.
--
--   A variable that is a key of @vtab@ is replaced by the name it maps
--   to. Variables absent from @vtab@, and all constants, are returned
--   unchanged.
substInSEs :: M.Map VName VName -> [SubExp] -> [SubExp]
substInSEs vtab ses = [rename se | se <- ses]
  where
    rename (Var v) = Var (M.findWithDefault v v vtab)
    rename se = se