{-# LANGUAGE TypeFamilies #-}

-- | An unstructured grab-bag of various tools and inspection
-- functions that didn't really fit anywhere else.
module Futhark.Tools
  ( module Futhark.Construct,
    redomapToMapAndReduce,
    scanomapToMapAndScan,
    dissectScrema,
    sequentialStreamWholeArray,
    partitionChunkedFoldParameters,
    withAcc,
    doScatter,

    -- * Primitive expressions
    module Futhark.Analysis.PrimExp.Convert,
  )
where

import Control.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Construct
import Futhark.IR
import Futhark.IR.SOACS.SOAC
import Futhark.Util

-- | Turns a binding of a @redomap@ into two separate bindings, a
-- @map@ binding and a @reduce@ binding (returned in that order).
--
-- Reuses the original pattern for the @reduce@, and creates a new
-- pattern with new 'Ident's for the result of the @map@.
--
-- The leading pattern elements (one per neutral element across all
-- reductions) become the pattern of the @reduce@.  The remaining
-- pattern elements are the mapped-out arrays and are bound directly by
-- the @map@.
redomapToMapAndReduce ::
  ( MonadFreshNames m,
    Buildable rep,
    ExpDec rep ~ (),
    Op rep ~ SOAC rep
  ) =>
  Pat (LetDec rep) ->
  ( SubExp,
    [Reduce rep],
    Lambda rep,
    [VName]
  ) ->
  m (Stm rep, Stm rep)
redomapToMapAndReduce (Pat pes) (w, reds, map_lam, arrs) = do
  (map_idents, red_pat, red_inputs) <-
    splitScanOrRedomap pes w map_lam (map redNeutral reds)
  red_form <- reduceSOAC reds
  let map_stm = mkLet map_idents . Op . Screma w arrs $ mapSOAC map_lam
      -- The reduction keeps the caller's pattern, so it is built with
      -- 'Let' directly rather than with 'mkLet'.
      red_stm = Let red_pat (defAux ()) . Op $ Screma w red_inputs red_form
  pure (map_stm, red_stm)

-- | Turns a binding of a @scanomap@ into two separate bindings, a
-- @map@ binding and a @scan@ binding (returned in that order).
--
-- This is the scan counterpart of 'redomapToMapAndReduce'.  The
-- original pattern is reused for the @scan@.  The @map@ gets a new
-- pattern with fresh 'Ident's for the values that are fed to the scan.
scanomapToMapAndScan ::
  ( MonadFreshNames m,
    Buildable rep,
    ExpDec rep ~ (),
    Op rep ~ SOAC rep
  ) =>
  Pat (LetDec rep) ->
  ( SubExp,
    [Scan rep],
    Lambda rep,
    [VName]
  ) ->
  m (Stm rep, Stm rep)
scanomapToMapAndScan (Pat pes) (w, scans, map_lam, arrs) = do
  (map_idents, scan_pat, scan_inputs) <-
    splitScanOrRedomap pes w map_lam (map scanNeutral scans)
  scan_form <- scanSOAC scans
  let map_stm = mkLet map_idents . Op . Screma w arrs $ mapSOAC map_lam
      -- The scan keeps the caller's pattern, so it is built with 'Let'
      -- directly rather than with 'mkLet'.
      scan_stm = Let scan_pat (defAux ()) . Op $ Screma w scan_inputs scan_form
  pure (map_stm, scan_stm)

-- | Split the pattern of a scanomap/redomap so it can be expressed as
-- a @map@ followed by a @scan@ or @reduce@.
--
-- Only the total number of neutral elements in @nes@ matters here.
-- That many leading pattern elements are the accumulator results.
--
-- For each accumulator result, a fresh 'Ident' is created.  Its name is
-- the pattern element's base name suffixed with @_map_acc@.  Its type is
-- the matching lambda return type, lifted to an array of outer size @w@.
-- The remaining pattern elements are the mapped-out arrays and are
-- reused unchanged.
--
-- Returns a triple:
--
-- * the pattern for the @map@ (the fresh accumulator idents, followed
--   by the reused array idents);
--
-- * the pattern for the @scan@/@reduce@ (the accumulator pattern
--   elements);
--
-- * the names of the arrays that the @scan@/@reduce@ consumes.
splitScanOrRedomap ::
  (Typed dec, MonadFreshNames m) =>
  [PatElem dec] ->
  SubExp ->
  Lambda rep ->
  [[SubExp]] ->
  m ([Ident], Pat dec, [VName])
splitScanOrRedomap pes w map_lam nes = do
  acc_idents <- zipWithM freshAccIdent acc_pes acc_ts
  let arr_idents = map patElemIdent arr_pes
  pure (acc_idents ++ arr_idents, Pat acc_pes, map identName acc_idents)
  where
    num_accs = length (concat nes)
    (acc_pes, arr_pes) = splitAt num_accs pes
    acc_ts = take num_accs $ lambdaReturnType map_lam
    freshAccIdent pe t =
      newIdent (baseName (patElemName pe) <> "_map_acc") (t `arrayOfRow` w)

-- | Turn a Screma into a Scanomap (possibly with mapout parts) and a
-- reduction.  This is used to handle Scremas that are so complicated
-- that we cannot directly generate efficient parallel code for them.
-- In essence, what happens is the opposite of horizontal fusion.
--
-- The original map lambda is run as a scanomap.  The values that were
-- destined for the reductions come out of it as fresh intermediate
-- arrays.  A second Screma, built with 'reduceSOAC', then reduces those
-- arrays into the reduction part of @pat@.
dissectScrema ::
  ( MonadBuilder m,
    Op (Rep m) ~ SOAC (Rep m),
    Buildable (Rep m)
  ) =>
  Pat (LetDec (Rep m)) ->
  SubExp ->
  ScremaForm (Rep m) ->
  [VName] ->
  m ()
dissectScrema pat w (ScremaForm map_lam scans reds) arrs = do
  -- The pattern is ordered as: scan results, reduction results, then
  -- mapped-out results.
  let (scan_names, red_names, mapout_names) =
        splitAt3 (scanResults scans) (redResults reds) (patNames pat)

  -- Intermediate arrays holding the not-yet-reduced values.
  red_inputs <- replicateM (redResults reds) (newVName "to_red")

  letBindNames (scan_names <> red_inputs <> mapout_names) . Op . Screma w arrs $
    scanomapSOAC scans map_lam

  red_form <- reduceSOAC reds
  letBindNames red_names . Op $ Screma w red_inputs red_form

-- | Turn a stream SOAC into statements that apply the stream lambda
-- to the entire input.
--
-- The lambda parameters are bound as follows:
--
-- * the chunk size parameter to @w@;
--
-- * the accumulator parameters to the neutral elements @nes@;
--
-- * the chunk parameters to the arrays @arrs@.
--
-- The lambda body is then inlined, and its results are bound, in
-- order, to the elements of @pat@.
sequentialStreamWholeArray ::
  (MonadBuilder m, Buildable (Rep m)) =>
  Pat (LetDec (Rep m)) ->
  SubExp ->
  [SubExp] ->
  Lambda (Rep m) ->
  [VName] ->
  m ()
sequentialStreamWholeArray pat w nes lam arrs = do
  -- The whole input is treated as one chunk and the body is inlined.
  -- Parallel and sequential streams are handled identically here.
  let (chunk_size_param, acc_params, chunk_params) =
        partitionChunkedFoldParameters (length nes) $ lambdaParams lam
      lam_body = lambdaBody lam

  -- A single chunk spans everything, so the chunk size is w.
  letBindNames [paramName chunk_size_param] . BasicOp $ SubExp w

  -- Accumulators start out as the neutral elements.
  forM_ (zip acc_params nes) $ \(acc_param, ne) ->
    letBindNames [paramName acc_param] . BasicOp $ SubExp ne

  -- Chunk parameters are bound to the input arrays.  Array-typed ones
  -- go through a shape coercion so the types line up; this is expected
  -- to be simplified away quickly.
  forM_ (zip chunk_params arrs) $ \(chunk_param, arr) ->
    letBindNames [paramName chunk_param] $
      case arrayDims (paramType chunk_param) of
        [] -> BasicOp . SubExp $ Var arr
        dims -> shapeCoerce dims arr

  -- Inline the lambda body.
  mapM_ addStm $ bodyStms lam_body

  -- The body results match the elements of 'pat' one-to-one and in
  -- order.  Each result is bound under its certificates, and array
  -- variables are coerced to the pattern element's shape.
  forM_ (zip (patElems pat) (bodyResult lam_body)) $ \(pe, SubExpRes cs se) ->
    certifying cs . letBindNames [patElemName pe] $
      case (arrayDims (patElemType pe), se) of
        (dims@(_ : _), Var v) -> shapeCoerce dims v
        _ -> BasicOp $ SubExp se

-- | Split the parameters of a stream reduction lambda into three
-- parts: the chunk size parameter, the accumulator parameters, and the
-- input chunk parameters.  The integer argument is the number of
-- accumulators.
--
-- The first parameter is always taken to be the chunk size.  Calls
-- 'error' if the parameter list is empty.  If there are fewer
-- parameters than accumulators, the remaining lists are simply shorter.
partitionChunkedFoldParameters ::
  Int ->
  [Param dec] ->
  (Param dec, [Param dec], [Param dec])
partitionChunkedFoldParameters num_accs params =
  case params of
    [] -> error "partitionChunkedFoldParameters: lambda takes no parameters"
    chunk_param : rest ->
      let (acc_params, arr_params) = splitAt num_accs rest
       in (chunk_param, acc_params, arr_params)

-- | Construct a scatter-like 'WithAcc' over the destination arrays
-- @dest@.  The closure is invoked with the names of the accumulators.
-- Its results become the results of the 'WithAcc' lambda.
--
-- The accumulators span the outermost @rank@ dimensions, with sizes
-- taken from the first destination array.  The element type of each
-- accumulator is the corresponding destination type with those @rank@
-- dimensions stripped.  One certificate parameter is created per
-- destination.
withAcc ::
  (MonadBuilder m, LParam (Rep m) ~ Param Type) =>
  [VName] ->
  Int ->
  ([VName] -> m [SubExp]) ->
  m (Exp (Rep m))
withAcc dest rank mk = do
  cert_ps <- replicateM (length dest) $ newParam "acc_cert" $ Prim Unit
  dest_ts <- mapM lookupType dest
  -- Only the first destination is inspected for the accumulator shape.
  -- Presumably all destinations agree on their outer @rank@ dimensions;
  -- confirm at call sites.
  --
  -- With no destinations the shape is never used, so instead of the
  -- partial 'head' we fall back to an empty shape.
  let acc_shape = case dest_ts of
        t : _ -> Shape $ take rank $ arrayDims t
        [] -> Shape []
      mkT cert elem_t = Acc cert acc_shape [elem_t] NoUniqueness
      acc_ts = zipWith mkT (map paramName cert_ps) $ map (stripArray rank) dest_ts
  acc_ps <- mapM (newParam "acc_p") acc_ts
  withacc_lam <-
    mkLambda (cert_ps <> acc_ps) $ subExpsRes <$> mk (map paramName acc_ps)
  pure $ WithAcc [(acc_shape, [v], Nothing) | v <- dest] withacc_lam

-- | Perform a scatter-like operation using accumulators and map.
doScatter ::
  (MonadBuilder m, Buildable (Rep m), Op (Rep m) ~ SOAC (Rep m)) =>
  Name ->
  Int ->
  [VName] ->
  [VName] ->
  ([LParam (Rep m)] -> m [SubExp]) ->
  m [VName]
doScatter :: forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m), Op (Rep m) ~ SOAC (Rep m)) =>
Name
-> Int
-> [VName]
-> [VName]
-> ([LParam (Rep m)] -> m [SubExp])
-> m [VName]
doScatter Name
desc Int
rank [VName]
dest [VName]
arrs [LParam (Rep m)] -> m [SubExp]
mk = do
  [Param Type]
cert_ps <- Int -> m (Param Type) -> m [Param Type]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM ([VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
dest) (m (Param Type) -> m [Param Type])
-> m (Param Type) -> m [Param Type]
forall a b. (a -> b) -> a -> b
$ Name -> Type -> m (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
Name -> dec -> m (Param dec)
newParam Name
"acc_cert" (Type -> m (Param Type)) -> Type -> m (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Unit
  [Type]
dest_ts <- (VName -> m Type) -> [VName] -> m [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> m Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
dest
  let acc_shape :: ShapeBase SubExp
acc_shape = [SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> ShapeBase SubExp) -> [SubExp] -> ShapeBase SubExp
forall a b. (a -> b) -> a -> b
$ Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
take Int
rank ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Type -> [SubExp]
forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims (Type -> [SubExp]) -> Type -> [SubExp]
forall a b. (a -> b) -> a -> b
$ [Type] -> Type
forall a. HasCallStack => [a] -> a
head [Type]
dest_ts
      mkT :: VName -> Type -> Type
mkT VName
cert Type
elem_t = VName -> ShapeBase SubExp -> [Type] -> NoUniqueness -> Type
forall shape u.
VName -> ShapeBase SubExp -> [Type] -> u -> TypeBase shape u
Acc VName
cert ShapeBase SubExp
acc_shape [Type
elem_t] NoUniqueness
NoUniqueness
      acc_ts :: [Type]
acc_ts =
        (VName -> Type -> Type) -> [VName] -> [Type] -> [Type]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith VName -> Type -> Type
mkT ((Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
cert_ps) ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$
          (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> Type -> Type
forall u.
Int
-> TypeBase (ShapeBase SubExp) u -> TypeBase (ShapeBase SubExp) u
stripArray Int
rank) [Type]
dest_ts
  [Param Type]
acc_ps <- (Type -> m (Param Type)) -> [Type] -> m [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Name -> Type -> m (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
Name -> dec -> m (Param dec)
newParam Name
"acc_p") [Type]
acc_ts
  [Type]
arrs_ts <- (VName -> m Type) -> [VName] -> m [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> m Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
arrs

  Lambda (Rep m)
withacc_lam <- [LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda ([Param Type]
cert_ps [Param Type] -> [Param Type] -> [Param Type]
forall a. Semigroup a => a -> a -> a
<> [Param Type]
acc_ps) (m [SubExpRes] -> m (Lambda (Rep m)))
-> m [SubExpRes] -> m (Lambda (Rep m))
forall a b. (a -> b) -> a -> b
$ do
    [Param Type]
acc_ps_inner <- (Type -> m (Param Type)) -> [Type] -> m [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Name -> Type -> m (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
Name -> dec -> m (Param dec)
newParam Name
"acc_p") [Type]
acc_ts
    [Param Type]
params <- (Type -> m (Param Type)) -> [Type] -> m [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Name -> Type -> m (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
Name -> dec -> m (Param dec)
newParam Name
"v" (Type -> m (Param Type))
-> (Type -> Type) -> Type -> m (Param Type)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> Type -> Type
forall u.
Int
-> TypeBase (ShapeBase SubExp) u -> TypeBase (ShapeBase SubExp) u
stripArray Int
1) [Type]
arrs_ts
    Lambda (Rep m)
map_lam <-
      [LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda ([Param Type]
acc_ps_inner [Param Type] -> [Param Type] -> [Param Type]
forall a. Semigroup a => a -> a -> a
<> [Param Type]
params) (m [SubExpRes] -> m (Lambda (Rep m)))
-> m [SubExpRes] -> m (Lambda (Rep m))
forall a b. (a -> b) -> a -> b
$ do
        ([SubExp]
is, [SubExp]
vs) <- Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
rank ([SubExp] -> ([SubExp], [SubExp]))
-> m [SubExp] -> m ([SubExp], [SubExp])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [LParam (Rep m)] -> m [SubExp]
mk [Param Type]
[LParam (Rep m)]
params
        ([SubExp] -> [SubExpRes]) -> m [SubExp] -> m [SubExpRes]
forall a b. (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [SubExp] -> [SubExpRes]
subExpsRes (m [SubExp] -> m [SubExpRes]) -> m [SubExp] -> m [SubExpRes]
forall a b. (a -> b) -> a -> b
$ [(Param Type, SubExp)]
-> ((Param Type, SubExp) -> m SubExp) -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Param Type] -> [SubExp] -> [(Param Type, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param Type]
acc_ps_inner [SubExp]
vs) (((Param Type, SubExp) -> m SubExp) -> m [SubExp])
-> ((Param Type, SubExp) -> m SubExp) -> m [SubExp]
forall a b. (a -> b) -> a -> b
$ \(Param Type
acc_p_inner, SubExp
v) ->
          Name -> Exp (Rep m) -> m SubExp
forall (m :: * -> *).
MonadBuilder m =>
Name -> Exp (Rep m) -> m SubExp
letSubExp Name
"scatter_acc" (Exp (Rep m) -> m SubExp)
-> (BasicOp -> Exp (Rep m)) -> BasicOp -> m SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> m SubExp) -> BasicOp -> m SubExp
forall a b. (a -> b) -> a -> b
$
            Safety -> VName -> [SubExp] -> [SubExp] -> BasicOp
UpdateAcc Safety
Safe (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
acc_p_inner) [SubExp]
is [SubExp
v]
    let w :: SubExp
w = Int -> [Type] -> SubExp
forall u. Int -> [TypeBase (ShapeBase SubExp) u] -> SubExp
arraysSize Int
0 [Type]
arrs_ts
    ([VName] -> [SubExpRes]) -> m [VName] -> m [SubExpRes]
forall a b. (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (m [VName] -> m [SubExpRes])
-> (SOAC (Rep m) -> m [VName]) -> SOAC (Rep m) -> m [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Name -> Exp (Rep m) -> m [VName]
forall (m :: * -> *).
MonadBuilder m =>
Name -> Exp (Rep m) -> m [VName]
letTupExp Name
"acc_res" (Exp (Rep m) -> m [VName])
-> (SOAC (Rep m) -> Exp (Rep m)) -> SOAC (Rep m) -> m [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op (Rep m) -> Exp (Rep m)
SOAC (Rep m) -> Exp (Rep m)
forall rep. Op rep -> Exp rep
Op (SOAC (Rep m) -> m [SubExpRes]) -> SOAC (Rep m) -> m [SubExpRes]
forall a b. (a -> b) -> a -> b
$
      SubExp -> [VName] -> ScremaForm (Rep m) -> SOAC (Rep m)
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w ((Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
acc_ps [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
arrs) (Lambda (Rep m) -> ScremaForm (Rep m)
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda (Rep m)
map_lam)

  Name -> Exp (Rep m) -> m [VName]
forall (m :: * -> *).
MonadBuilder m =>
Name -> Exp (Rep m) -> m [VName]
letTupExp Name
desc (Exp (Rep m) -> m [VName]) -> Exp (Rep m) -> m [VName]
forall a b. (a -> b) -> a -> b
$ [WithAccInput (Rep m)] -> Lambda (Rep m) -> Exp (Rep m)
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [(ShapeBase SubExp
acc_shape, [VName
v], Maybe (Lambda (Rep m), [SubExp])
forall a. Maybe a
Nothing) | VName
v <- [VName]
dest] Lambda (Rep m)
withacc_lam