{-# LANGUAGE TypeFamilies #-}

module Futhark.Analysis.HORep.MapNest
  ( Nesting (..),
    MapNest (..),
    depth,
    typeOf,
    params,
    inputs,
    setInputs,
    fromSOAC,
    toSOAC,
    reshape,
  )
where

import Control.Monad (replicateM)
import Data.List (find)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.HORep.SOAC (SOAC)
import Futhark.Analysis.HORep.SOAC qualified as SOAC
import Futhark.Construct
import Futhark.IR hiding (typeOf)
import Futhark.IR.SOACS (SOACS)
import Futhark.IR.SOACS.SOAC qualified as Futhark
import Futhark.Transform.Substitute

-- | One intermediate level of a 'MapNest': the lambda of a map whose
-- body does nothing but compute a map nested one level further in.
data Nesting = Nesting
  { -- | Names of the lambda parameters at this level.
    nestingParamNames :: [VName],
    -- | Names bound to the result of the map nested inside this level.
    nestingResult :: [VName],
    -- | Return type of the lambda at this level.
    nestingReturnType :: [Type],
    -- | Width of the map nested directly inside this level.
    nestingWidth :: SubExp
  }
  deriving (Show, Eq, Ord)

-- | A perfect nest of maps.  The outermost map has width 'mapNestWidth'
-- and consumes 'mapNestInput'; 'mapNestNestings' describes the
-- intermediate levels, outermost first; 'mapNestLambda' is the lambda of
-- the innermost map.
data MapNest = MapNest
  { -- | Width of the outermost map.
    mapNestWidth :: SubExp,
    -- | Lambda of the innermost map.
    mapNestLambda :: Lambda SOACS,
    -- | Intermediate levels, outermost first.
    mapNestNestings :: [Nesting],
    -- | Inputs of the outermost map.
    mapNestInput :: [SOAC.Input]
  }
  deriving (Show)

-- | Number of map levels in the nest: the innermost map plus one level
-- per 'Nesting'.
depth :: MapNest -> Int
depth = (1 +) . length . mapNestNestings

-- | Types of the values produced by the whole nest.  These are the row
-- types returned by the outermost lambda (the first nesting's return
-- type, or the innermost lambda's when there are no nestings), each
-- turned into an array of the outer width.
typeOf :: MapNest -> [Type]
typeOf mn = map (`arrayOfRow` mapNestWidth mn) rowTypes
  where
    rowTypes = case mapNestNestings mn of
      [] -> lambdaReturnType (mapNestLambda mn)
      outermost : _ -> nestingReturnType outermost

-- | Parameter names of the outermost lambda.  These come from the first
-- nesting, or from the innermost lambda when there are no nestings.
params :: MapNest -> [VName]
params mn =
  case mapNestNestings mn of
    outermost : _ -> nestingParamNames outermost
    [] -> map paramName $ lambdaParams $ mapNestLambda mn

-- | Inputs consumed by the outermost map of the nest.
inputs :: MapNest -> [SOAC.Input]
inputs = mapNestInput

-- | Replace the inputs of the outermost map.
--
-- With an empty list, only the inputs are cleared.  With a non-empty
-- list, the outer width is re-derived from the outer dimension of the
-- first new input.  The widths of the intermediate nestings are taken, in
-- order, from that input's remaining dimensions.
--
-- NOTE(review): because of 'zipWith', nestings beyond the first input's
-- rank are dropped.  Callers presumably always pass inputs of sufficient
-- rank; confirm.
setInputs :: [SOAC.Input] -> MapNest -> MapNest
setInputs [] mn = mn {mapNestInput = []}
setInputs new_inps@(first_inp : _) mn =
  mn
    { mapNestWidth = arraySize 0 first_t,
      mapNestNestings =
        zipWith withWidth (mapNestNestings mn) (drop 1 (arrayDims first_t)),
      mapNestInput = new_inps
    }
  where
    first_t = SOAC.inputType first_inp
    withWidth nest nw = nest {nestingWidth = nw}

-- | Try to move the statements @stms@ inside the lambda of @stm@, which
-- must be a map-only 'Futhark.Screma'.  The statements are prepended to
-- the lambda body.
--
-- Returns 'Nothing' in two cases:
--
-- * @stm@ is not such a map.
-- * A part of @stm@ that remains outside the lambda refers to a name
--   bound by @stms@.  Moving the statements would then leave that name
--   unbound.
--
-- The parts checked are the map width, the input arrays, the statement's
-- certificates and its pattern (whose types may mention sizes).  The
-- width must be checked on its own: a map without input arrays has no
-- input type tying its width to names already in scope.
pushIntoMapLambda ::
  Stms SOACS ->
  Stm SOACS ->
  Maybe (Stm SOACS)
pushIntoMapLambda stms (Let pat aux (Op (Futhark.Screma w inps form)))
  | Just map_lam <- Futhark.isMapSOAC form,
    not $ used_outside_lambda `namesIntersect` bound_by_stms =
      let lam_body = lambdaBody map_lam
          map_lam' =
            map_lam {lambdaBody = lam_body {bodyStms = stms <> bodyStms lam_body}}
          form' = Futhark.mapSOAC map_lam'
       in Just $ Let pat aux (Op (Futhark.Screma w inps form'))
  where
    -- Names bound by the statements we want to move.
    bound_by_stms = namesFromList $ foldMap (patNames . stmPat) stms
    -- Names referenced by the parts of the statement that stay outside
    -- the lambda after the move.
    used_outside_lambda =
      freeIn w <> freeIn inps <> freeIn (stmAuxCerts aux) <> freeIn pat
pushIntoMapLambda _ _ = Nothing

-- | Rewrite a map-only 'SOAC.Screma' whose lambda body is a run of cheap
-- scalar statements (see @isCheap@) followed by one final statement.
--
-- The rewrite applies when two conditions hold.  First, none of the
-- names bound by the run may occur free in the lambda result.  Second,
-- 'pushIntoMapLambda' must succeed in moving the run into the final
-- statement's map lambda.  The body is then reduced to that single
-- rewritten statement.  'fromSOAC'' applies this before looking for a
-- nested map.  Any other SOAC is returned unchanged.
massage :: SOAC SOACS -> SOAC SOACS
massage soac@(SOAC.Screma w inps form) =
  fromMaybe soac $ do
    lam <- Futhark.isMapSOAC form
    let body = lambdaBody lam
    (prefix, final) <- stmsLast $ bodyStms body
    require $ all (isCheap . stmExp) prefix
    require $
      all (`notNameIn` freeIn (bodyResult body)) $
        foldMap (patNames . stmPat) prefix
    final' <- pushIntoMapLambda prefix final
    pure . SOAC.Screma w inps . Futhark.mapSOAC $
      lam {lambdaBody = body {bodyStms = oneStm final'}}
  where
    require ok = if ok then Just () else Nothing

    -- Statements that are cheap enough to recompute inside the inner map.
    isCheap :: Exp SOACS -> Bool
    isCheap e = case e of
      BasicOp BinOp {} -> True
      BasicOp SubExp {} -> True
      BasicOp CmpOp {} -> True
      BasicOp ConvOp {} -> True
      BasicOp UnOp {} -> True
      _ -> False
massage soac = soac

-- | Recognise a perfect nest of maps starting at @soac@.
--
-- @bound@ lists the parameters of all enclosing map lambdas visited so
-- far.  The SOAC is first passed through 'massage'.
--
-- If the result is a map whose lambda body is a single statement, whose
-- results are returned unchanged, and which is itself a map nest, one
-- 'Nesting' is stacked on top of that inner nest.
--
-- Otherwise this map becomes the innermost level.  Each name from @bound@
-- that its lambda uses freely becomes a new lambda parameter (suffixed
-- @_wasfree@).  Each such parameter is fed by an input replicated over
-- the map width.
--
-- Returns 'Nothing' when the massaged SOAC is not a plain map.
fromSOAC' ::
  (MonadFreshNames m, LocalScope SOACS m) =>
  [Ident] ->
  SOAC SOACS ->
  m (Maybe MapNest)
fromSOAC' bound soac =
  case massage soac of
    SOAC.Screma w inps (SOAC.ScremaForm lam [] []) -> do
      let lam_params = lambdaParams lam
          lam_body = lambdaBody lam
          -- Parameters visible to any map nested inside this one.
          bound' = bound <> map paramIdent lam_params

      -- Is the body a single SOAC statement whose results are returned
      -- unchanged, and does that SOAC form a map nest of its own?
      nested <- case (stmsToList (bodyStms lam_body), bodyResult lam_body) of
        ([Let pat _ e], res)
          | map resSubExp res == map Var (patNames pat) ->
              localScope (scopeOfLParams lam_params) $
                either (const (pure Nothing)) (fmap (fmap (pat,)) . fromSOAC' bound')
                  =<< SOAC.fromExp e
        _ -> pure Nothing

      Just <$> case nested of
        -- A nested map nest: wrap it in one more level.
        Just (pat, mn@(MapNest inner_w inner_lam inner_nests inner_inps)) -> do
          (level_params, fixed_inps) <-
            unzip
              <$> fixInputs
                w
                (zip (map paramName lam_params) inps)
                (zip (params mn) inner_inps)
          let level =
                Nesting
                  { nestingParamNames = level_params,
                    nestingResult = patNames pat,
                    nestingReturnType = typeOf mn,
                    nestingWidth = inner_w
                  }
          pure $ MapNest w inner_lam (level : inner_nests) fixed_inps
        -- This map is the innermost one.
        Nothing -> do
          let lookupBound v = find ((v ==) . identName) bound
              free_bound = mapMaybe lookupBound $ namesToList $ freeIn lam
          fresh <- traverse (newIdent' (++ "_wasfree")) free_bound
          let renaming =
                M.fromList $ zip (map identName free_bound) (map identName fresh)
              replicated =
                map
                  (SOAC.addTransform (SOAC.Replicate mempty $ Shape [w]) . SOAC.identInput)
                  free_bound
              lam' =
                lam
                  { lambdaBody = substituteNames renaming lam_body,
                    lambdaParams =
                      lam_params
                        ++ [Param mempty (identName p) (identType p) | p <- fresh]
                  }
          pure $ MapNest w lam' [] (inps ++ replicated)
    _ -> pure Nothing

-- | Recognise a perfect map nest in a SOAC that has no enclosing lambda
-- parameters.  Returns 'Nothing' if the SOAC is not a map.
fromSOAC :: (MonadFreshNames m, LocalScope SOACS m) => SOAC SOACS -> m (Maybe MapNest)
fromSOAC = fromSOAC' []

-- | Turn a 'MapNest' back into a (nested) map 'SOAC'.
--
-- Each 'Nesting' becomes a map lambda whose body binds the nesting's
-- result names to the map one level further in, and then returns them.
-- The innermost level uses 'mapNestLambda' directly.  The parameters of
-- each generated lambda get the row types of that level's inputs.
toSOAC :: (MonadFreshNames m, HasScope SOACS m) => MapNest -> m (SOAC SOACS)
toSOAC (MapNest w lam nests inps) =
  case nests of
    [] ->
      pure $ asMap lam
    Nesting npnames nres nrettype nw : rest -> do
      let nparams = zipWith (Param mempty) npnames $ map SOAC.inputRowType inps
          inner = MapNest nw lam rest $ map (SOAC.identInput . paramIdent) nparams
      body <- runBodyBuilder . localScope (scopeOfLParams nparams) $ do
        inner_soac <- toSOAC inner
        inner_exp <- SOAC.toExp inner_soac
        letBindNames nres inner_exp
        pure $ varsRes nres
      pure . asMap $
        Lambda
          { lambdaParams = nparams,
            lambdaBody = body,
            lambdaReturnType = nrettype
          }
  where
    -- Wrap a lambda as a map over this level's width and inputs.
    asMap = SOAC.Screma w inps . Futhark.mapSOAC

-- | Compute the parameter names and inputs for a new 'Nesting' level.
--
-- The arguments are:
--
-- * @w@, the width of the outer map;
-- * @ourInps@, the outer map's (parameter, input) pairs;
-- * the (parameter, input) pairs of the nest one level in.
--
-- An inner input whose array is one of the outer parameters is replaced
-- by the matching outer input, with the inner input's transforms applied
-- to its rows.  Any other inner input is kept, with a 'SOAC.Replicate'
-- over @w@ appended to its transforms.  Every returned parameter name is
-- fresh; the replicated ones get a @_rep@ suffix.
fixInputs ::
  (MonadFreshNames m) =>
  SubExp ->
  [(VName, SOAC.Input)] ->
  [(VName, SOAC.Input)] ->
  m [(VName, SOAC.Input)]
fixInputs w ourInps = traverse fixInput
  where
    fixInput (param, SOAC.Input ts arr t) =
      case find ((arr ==) . fst) ourInps of
        Just (p, p_inp) -> do
          p' <- newNameFromString $ baseString p
          pure (p', SOAC.transformRows ts p_inp)
        Nothing -> do
          param' <- newNameFromString $ baseString param ++ "_rep"
          pure (param', SOAC.Input (ts SOAC.|> SOAC.Replicate mempty (Shape [w])) arr t)

-- | Reshape a map nest. It is assumed that any validity tests have
-- already been done. Will automatically reshape the inputs
-- appropriately.
--
-- The first dimension of the given shape becomes the width of the
-- outermost map. Every remaining dimension gets its own 'Nesting'
-- (outermost first) with freshly named parameters and results. Each
-- input receives a 'SOAC.Reshape' transform (carrying the given
-- 'StmAux') whose target shape is the new shape followed by the shape
-- of the corresponding lambda parameter.
reshape :: (MonadFreshNames m) => StmAux () -> Shape -> MapNest -> m MapNest
reshape aux shape (MapNest _ map_lam _ inps) =
  (\nestings -> MapNest outer_width map_lam nestings reshaped_inps)
    <$> nestingsFor (drop 1 $ shapeDims shape)
  where
    outer_width = shapeSize 0 shape
    lam_params = lambdaParams map_lam
    lam_ret = lambdaReturnType map_lam

    -- Give an input the new outer shape, keeping the row shape that the
    -- matching lambda parameter expects.
    reshapeInput param_t inp =
      SOAC.addTransform
        ( SOAC.Reshape aux . reshapeAll (arrayShape $ SOAC.inputType inp) $
            shape <> arrayShape param_t
        )
        inp
    reshaped_inps = zipWith reshapeInput (map paramType lam_params) inps

    -- Build one nesting per remaining dimension, outermost first. The
    -- fresh names of a level are generated before those of the levels
    -- nested inside it. A level's width is its leading dimension, and
    -- its return types are the lambda's return types lifted over the
    -- dimensions from that level inwards.
    nestingsFor [] = pure []
    nestingsFor dims@(d : ds) = do
      param_names <- mapM (newVName . baseString . paramName) lam_params
      res_names <- replicateM (length lam_ret) $ newVName "mapnest_res"
      let ret_ts = map (`arrayOfShape` Shape dims) lam_ret
      (Nesting param_names res_names ret_ts d :) <$> nestingsFor ds