{-# LANGUAGE TypeFamilies #-}
module Futhark.Analysis.HORep.SOAC
(
SOAC (..),
Futhark.ScremaForm (..),
inputs,
setInputs,
lambda,
setLambda,
typeOf,
width,
NotSOAC (..),
fromExp,
toExp,
toSOAC,
Input (..),
varInput,
inputTransforms,
identInput,
isVarInput,
isVarishInput,
addTransform,
addInitialTransforms,
inputArray,
inputRank,
inputType,
inputRowType,
transformRows,
transposeInput,
applyTransforms,
ArrayTransforms,
noTransforms,
nullTransforms,
(|>),
(<|),
viewf,
ViewF (..),
viewl,
ViewL (..),
ArrayTransform (..),
transformFromExp,
transformToExp,
soacToStream,
)
where
import Data.Foldable as Foldable
import Data.Maybe
import Data.Sequence qualified as Seq
import Futhark.Construct hiding (toExp)
import Futhark.IR hiding
( Index,
Iota,
Rearrange,
Replicate,
Reshape,
typeOf,
)
import Futhark.IR qualified as Futhark
import Futhark.IR.SOACS.SOAC
( HistOp (..),
ScatterSpec,
ScremaForm (..),
scremaType,
)
import Futhark.IR.SOACS.SOAC qualified as Futhark
import Futhark.Transform.Rename (renameLambda)
import Futhark.Transform.Substitute
import Futhark.Util.Pretty (pretty)
import Futhark.Util.Pretty qualified as PP
-- | A single, simple transformation of an array input.  Every
-- constructor carries the certificates that must hold when the
-- transformation is materialised (see 'applyTransform', which wraps
-- the generated expression in 'certifying').  Sequences of transforms
-- should be built as 'ArrayTransforms', not as plain lists.
data ArrayTransform
  = -- | A permutation of the array's dimensions.  The permutation may
    -- cover only the outer dimensions; 'transformToExp' extends it
    -- with the identity on the remaining ones.
    Rearrange Certs [Int]
  | -- | Reshape the whole array to the given shape.
    Reshape Certs ReshapeKind Shape
  | -- | Replace the outermost dimension with the given dimensions.
    ReshapeOuter Certs ReshapeKind Shape
  | -- | Keep the outermost dimension and replace all inner dimensions
    -- with the given ones.
    ReshapeInner Certs ReshapeKind Shape
  | -- | Replicate the array, adding the given shape as new outer
    -- dimensions.
    Replicate Certs Shape
  | -- | Index/slice the array.
    Index Certs (Slice SubExp)
  deriving (Show, Eq, Ord)
-- | Renames the certificates of every transform, as well as any names
-- occurring in its shape or slice.  Permutations and reshape kinds
-- contain no names and are kept as-is.
instance Substitute ArrayTransform where
  substituteNames substs (Rearrange cs xs) =
    Rearrange (substituteNames substs cs) xs
  substituteNames substs (Reshape cs k ses) =
    Reshape (substituteNames substs cs) k (substituteNames substs ses)
  substituteNames substs (ReshapeOuter cs k ses) =
    ReshapeOuter (substituteNames substs cs) k (substituteNames substs ses)
  substituteNames substs (ReshapeInner cs k ses) =
    ReshapeInner (substituteNames substs cs) k (substituteNames substs ses)
  substituteNames substs (Replicate cs se) =
    Replicate (substituteNames substs cs) (substituteNames substs se)
  substituteNames substs (Index cs slice) =
    Index (substituteNames substs cs) (substituteNames substs slice)
-- | An ordered sequence of array transformations, applied from first
-- to last (see 'applyTransforms').  Decompose it with 'viewf' and
-- 'viewl', and grow it with '|>' and '<|'.  Growing the sequence
-- merges adjacent 'Rearrange's and drops resulting identity
-- permutations.  The constructor is not exported, so this
-- simplification cannot be bypassed from outside the module.
newtype ArrayTransforms = ArrayTransforms (Seq.Seq ArrayTransform)
  deriving (Eq, Ord, Show)
-- | Concatenation: the transforms of the right operand are appended
-- one at a time with '|>'.  Transforms meeting at the seam are
-- therefore simplified exactly as '|>' would simplify them.
instance Semigroup ArrayTransforms where
  ts1 <> ts2 = case viewf ts2 of
    t :< ts2' -> (ts1 |> t) <> ts2'
    EmptyF -> ts1
-- | The empty transformation sequence is the identity for '<>'.
instance Monoid ArrayTransforms where
  mempty = noTransforms
-- | Substitutes names in every transform of the sequence, preserving
-- order.  No re-simplification is performed.
instance Substitute ArrayTransforms where
  substituteNames substs (ArrayTransforms ts) =
    ArrayTransforms $ substituteNames substs <$> ts
-- | The empty transformation sequence.
noTransforms :: ArrayTransforms
noTransforms = ArrayTransforms Seq.empty
-- | Is the transformation sequence empty?
nullTransforms :: ArrayTransforms -> Bool
nullTransforms (ArrayTransforms s) = Seq.null s
-- | Decompose the sequence into its first transform (the one applied
-- first) and the remaining transforms.
viewf :: ArrayTransforms -> ViewF
viewf (ArrayTransforms s) = case Seq.viewl s of
  t Seq.:< s' -> t :< ArrayTransforms s'
  Seq.EmptyL -> EmptyF
-- | A view of the first transformation to be applied, as produced by
-- 'viewf'.
data ViewF
  = EmptyF
  | ArrayTransform :< ArrayTransforms
-- | Decompose the sequence into all but its last transform, and the
-- last transform (the one applied last).
viewl :: ArrayTransforms -> ViewL
viewl (ArrayTransforms s) = case Seq.viewr s of
  s' Seq.:> t -> ArrayTransforms s' :> t
  Seq.EmptyR -> EmptyL
-- | A view of the last transformation to be applied, as produced by
-- 'viewl'.
data ViewL
  = EmptyL
  | ArrayTransforms :> ArrayTransform
-- | Add a transform to the end of the sequence, so that it is applied
-- last.  If it can be merged with the current last transform (see
-- 'combineTransforms'), the two are merged instead, and the result is
-- dropped entirely if it is the identity.
(|>) :: ArrayTransforms -> ArrayTransform -> ArrayTransforms
ts |> t = addTransform' extract add laterFirst t ts
  where
    -- The neighbour at the end of the sequence, if any.
    extract ts' = case viewl ts' of
      EmptyL -> Nothing
      ts'' :> t' -> Just (t', ts'')
    add t' (ArrayTransforms ts') = ArrayTransforms $ ts' Seq.|> t'
    -- 'addTransform'' hands us (existing, new).  Here the new transform
    -- is applied after the existing one, so it goes first in the pair.
    laterFirst (existing, new) = (new, existing)
-- | Add a transform to the beginning of the sequence, so that it is
-- applied first.  If it can be merged with the current first
-- transform (see 'combineTransforms'), the two are merged instead,
-- and the result is dropped entirely if it is the identity.
(<|) :: ArrayTransform -> ArrayTransforms -> ArrayTransforms
t <| ts = addTransform' extract add id t ts
  where
    -- The neighbour at the front of the sequence, if any.
    extract ts' = case viewf ts' of
      EmptyF -> Nothing
      t' :< ts'' -> Just (t', ts'')
    add t' (ArrayTransforms ts') = ArrayTransforms $ t' Seq.<| ts'
-- | Shared implementation of '|>' and '<|'.  The arguments are:
--
-- * @extract@ splits off the neighbour on the side being extended;
--
-- * @add@ attaches a transform on that side;
--
-- * @swap@ turns the pair (existing neighbour, new transform) into the
--   order (applied later, applied earlier) expected by
--   'combineTransforms'.
--
-- If the neighbour and the new transform combine, the neighbour is
-- removed.  The combined transform is then discarded if it is the
-- identity, and otherwise recursively added, so it may merge further.
-- If they do not combine, the new transform is simply added.
addTransform' ::
  (ArrayTransforms -> Maybe (ArrayTransform, ArrayTransforms)) ->
  (ArrayTransform -> ArrayTransforms -> ArrayTransforms) ->
  ((ArrayTransform, ArrayTransform) -> (ArrayTransform, ArrayTransform)) ->
  ArrayTransform ->
  ArrayTransforms ->
  ArrayTransforms
addTransform' extract add swap t ts =
  fromMaybe (t `add` ts) $ do
    (t', ts') <- extract ts
    combined <- uncurry combineTransforms $ swap (t', t)
    Just $
      if identityTransform combined
        then ts'
        else addTransform' extract add swap combined ts'
-- | Does this transform certainly leave the array unchanged?  Only
-- permutations that map every position to itself are recognised.
-- All other transforms are conservatively reported as non-identity.
identityTransform :: ArrayTransform -> Bool
identityTransform (Rearrange _ perm) =
  Foldable.and $ zipWith (==) perm [0 ..]
identityTransform _ = False
-- | Try to merge two adjacent transforms into one.  The first argument
-- is the transform applied later and the second the one applied
-- earlier.  Only two permutations are merged: they are composed with
-- 'rearrangeCompose', and their certificates are concatenated.  Every
-- other pair yields 'Nothing'.
combineTransforms :: ArrayTransform -> ArrayTransform -> Maybe ArrayTransform
combineTransforms (Rearrange cs2 perm2) (Rearrange cs1 perm1) =
  Just $ Rearrange (cs1 <> cs2) $ perm2 `rearrangeCompose` perm1
combineTransforms _ _ = Nothing
-- | Recognise an expression that is a single array transform applied
-- to a variable.  On success, return that variable and the
-- corresponding 'ArrayTransform', carrying the given certificates.
-- Only rearrange, reshape, replicate of a variable, and index are
-- recognised.
transformFromExp :: Certs -> Exp rep -> Maybe (VName, ArrayTransform)
transformFromExp cs (BasicOp (Futhark.Rearrange perm v)) =
  Just (v, Rearrange cs perm)
transformFromExp cs (BasicOp (Futhark.Reshape k shape v)) =
  Just (v, Reshape cs k shape)
transformFromExp cs (BasicOp (Futhark.Replicate shape (Var v))) =
  Just (v, Replicate cs shape)
transformFromExp cs (BasicOp (Futhark.Index v slice)) =
  Just (v, Index cs slice)
transformFromExp _ _ = Nothing
-- | Turn a transform applied to the array @ia@ into an expression,
-- together with the certificates that must guard it.  The type of
-- @ia@ is looked up where the expression depends on its rank or
-- shape.
transformToExp ::
  (Monad m, HasScope rep m) =>
  ArrayTransform ->
  VName ->
  m (Certs, Exp rep)
transformToExp (Replicate cs n) ia =
  pure (cs, BasicOp $ Futhark.Replicate n (Var ia))
transformToExp (Rearrange cs perm) ia = do
  r <- arrayRank <$> lookupType ia
  -- The stored permutation may only cover the outer dimensions;
  -- extend it with the identity on the remaining ones.
  pure (cs, BasicOp $ Futhark.Rearrange (perm ++ [length perm .. r - 1]) ia)
transformToExp (Reshape cs k shape) ia =
  pure (cs, BasicOp $ Futhark.Reshape k shape ia)
transformToExp (ReshapeOuter cs k shape) ia = do
  shape' <- reshapeOuter shape 1 . arrayShape <$> lookupType ia
  pure (cs, BasicOp $ Futhark.Reshape k shape' ia)
transformToExp (ReshapeInner cs k shape) ia = do
  shape' <- reshapeInner shape 1 . arrayShape <$> lookupType ia
  pure (cs, BasicOp $ Futhark.Reshape k shape' ia)
transformToExp (Index cs slice) ia =
  pure (cs, BasicOp $ Futhark.Index ia slice)
-- | An input to a SOAC: an array variable with a sequence of
-- transformations to apply to it.  The 'Type' field is the type of the
-- untransformed variable; 'inputType' gives the type after the
-- transforms.
data Input = Input ArrayTransforms VName Type
  deriving (Show, Eq, Ord)
-- | Substitutes names in the transforms, the array variable, and its
-- type.
instance Substitute Input where
  substituteNames substs (Input ts v t) =
    Input
      (substituteNames substs ts)
      (substituteNames substs v)
      (substituteNames substs t)
-- | Create an untransformed input from a variable, looking up its type
-- in the scope.
varInput :: (HasScope t f) => VName -> f Input
varInput v = Input noTransforms v <$> lookupType v
-- | Create an untransformed input from an 'Ident', using the type it
-- carries.
identInput :: Ident -> Input
identInput v = Input noTransforms (identName v) (identType v)
-- | If the input has no transforms, return its variable.
isVarInput :: Input -> Maybe VName
isVarInput (Input ts v _)
  | nullTransforms ts = Just v
isVarInput _ = Nothing
-- | Like 'isVarInput', but also looks through leading
-- 'ReshapeCoerce' reshapes to a one-dimensional shape that carry no
-- certificates.
isVarishInput :: Input -> Maybe VName
isVarishInput (Input ts v t)
  | nullTransforms ts = Just v
  | Reshape cs ReshapeCoerce (Shape [_]) :< ts' <- viewf ts,
    cs == mempty =
      isVarishInput $ Input ts' v t
isVarishInput _ = Nothing
-- | Append a transform to the input, so it is applied after the
-- existing ones.  Simplification is done by '|>'.
addTransform :: ArrayTransform -> Input -> Input
addTransform tr (Input trs a t) = Input (trs |> tr) a t
-- | Prepend a sequence of transforms to the input, so they are applied
-- before the existing ones.
addInitialTransforms :: ArrayTransforms -> Input -> Input
addInitialTransforms ts (Input ots a t) = Input (ts <> ots) a t
-- | Materialise one transform of the array @ia@.  The expression from
-- 'transformToExp' is bound to a fresh variable, under the transform's
-- certificates, and that variable is returned.  The fresh name's base
-- is chosen from the kind of transform.
applyTransform :: (MonadBuilder m) => ArrayTransform -> VName -> m VName
applyTransform tr ia = do
  (cs, e) <- transformToExp tr ia
  certifying cs $ letExp s e
  where
    s = case tr of
      Replicate {} -> "replicate"
      Rearrange {} -> "rearrange"
      Reshape {} -> "reshape"
      ReshapeOuter {} -> "reshape_outer"
      ReshapeInner {} -> "reshape_inner"
      Index {} -> "index"
-- | Materialise a whole transform sequence, first to last, returning
-- the variable bound to the final result.  An empty sequence returns
-- the input variable unchanged.
applyTransforms :: (MonadBuilder m) => ArrayTransforms -> VName -> m VName
applyTransforms (ArrayTransforms ts) a = foldlM (flip applyTransform) a ts
-- | Materialise every input's transforms (see 'applyTransforms'),
-- returning the resulting variables in the same order as the inputs.
inputsToSubExps ::
  (MonadBuilder m) =>
  [Input] ->
  m [VName]
inputsToSubExps = mapM f
  where
    f (Input ts a _) = applyTransforms ts a
-- | The underlying (untransformed) array variable of an input.
inputArray :: Input -> VName
inputArray (Input _ v _) = v
-- | The transformations applied to an input.
inputTransforms :: Input -> ArrayTransforms
inputTransforms (Input ts _ _) = ts
-- | The type of an input after all of its transforms have been
-- applied.  It is computed by folding each transform's effect on the
-- type over the untransformed array type, first to last.
inputType :: Input -> Type
inputType (Input (ArrayTransforms ts) _ at) =
  Foldable.foldl' transformType at ts
  where
    transformType t (Replicate _ shape) =
      arrayOfShape t shape
    transformType t (Rearrange _ perm) =
      rearrangeType perm t
    transformType t (Reshape _ _ shape) =
      t `setArrayShape` shape
    -- Replace the outermost dimension.
    transformType t (ReshapeOuter _ _ shape) =
      let oldshape = shapeDims $ arrayShape t
       in t `setArrayShape` Shape (shapeDims shape ++ drop 1 oldshape)
    -- Keep the outermost dimension and replace the rest.
    transformType t (ReshapeInner _ _ shape) =
      let oldshape = shapeDims $ arrayShape t
       in t `setArrayShape` Shape (take 1 oldshape ++ shapeDims shape)
    transformType t (Index _ slice) =
      t `setArrayShape` sliceShape slice
-- | The type of a single row of an 'Input', i.e. 'rowType' of its
-- fully transformed type.
inputRowType :: Input -> Type
inputRowType inp = rowType (inputType inp)
-- | The rank of an 'Input' after its transformations have been
-- applied.
inputRank :: Input -> Int
inputRank inp = arrayRank (inputType inp)
-- | Apply the given transformations to every row of an 'Input' rather
-- than to the 'Input' as a whole.  The outermost dimension is left in
-- place, and each transformation is rewritten so that it acts on the
-- dimensions inside it.
--
-- Only 'Rearrange', 'Reshape' and 'Replicate' can currently be lifted
-- like this; any other transformation is an internal error.
transformRows :: ArrayTransforms -> Input -> Input
transformRows (ArrayTransforms ts) inp0 =
  Foldable.foldl' transformRows' inp0 ts
  where
    -- Shift the permutation by one so the outer dimension stays put.
    transformRows' inp (Rearrange cs perm) =
      addTransform (Rearrange cs (0 : map (+ 1) perm)) inp
    transformRows' inp (Reshape cs k shape) =
      addTransform (ReshapeInner cs k shape) inp
    transformRows' inp (Replicate cs n)
      -- [a] ==> [n..., a] ==> [a, n...]
      | inputRank inp == 1 =
          Rearrange mempty (rk : [0 .. rk - 1])
            `addTransform` (Replicate cs n `addTransform` inp)
      -- [a, d1, ds...] ==> [d1, a, ds...] ==> [n..., d1, a, ds...]
      --                ==> [a, n..., d1, ds...]
      | otherwise =
          Rearrange mempty ((rk + 1) : [0 .. rk] ++ [rk + 2 .. rk + inputRank inp - 1])
            `addTransform` ( Replicate cs n
                               `addTransform` ( Rearrange mempty (1 : 0 : [2 .. inputRank inp - 1])
                                                  `addTransform` inp
                                              )
                           )
      where
        -- Number of dimensions added by the replication.  The
        -- permutations above reduce to the historical hard-coded ones
        -- ([1,0] and 2:0:1:[3..r]) when this is 1.
        rk = length (shapeDims n)
    transformRows' inp nts =
      error $ "transformRows: Cannot transform this yet:\n" ++ show nts ++ "\n" ++ show inp
-- | @transposeInput k n inp@ records a 'Rearrange' on @inp@.  The
-- permutation is 'transposeIndex' @k@ @n@ applied to the identity
-- permutation of the input's (transformed) rank.
transposeInput :: Int -> Int -> Input -> Input
transposeInput k n inp = addTransform (Rearrange mempty perm) inp
  where
    perm = transposeIndex k n [0 .. inputRank inp - 1]
-- | A SOAC whose array arguments are 'Input's (an array name plus
-- recorded 'ArrayTransforms') rather than plain variable names.  The
-- leading 'SubExp' of every constructor is the width (see 'width').
data SOAC rep
  = -- | Width, inputs, accumulator initial values, lambda.
    Stream SubExp [Input] [SubExp] (Lambda rep)
  | -- | Width, inputs, destination specification, lambda.
    Scatter SubExp [Input] (ScatterSpec VName) (Lambda rep)
  | -- | Width, inputs, and the form describing the computation.
    Screma SubExp [Input] (ScremaForm rep)
  | -- | Width, inputs, histogram operators, lambda.
    Hist SubExp [Input] [HistOp rep] (Lambda rep)
  deriving (Eq, Show)
-- | The array inputs of a SOAC.
inputs :: SOAC rep -> [Input]
inputs soac = case soac of
  Stream _ arrs _ _ -> arrs
  Scatter _ arrs _ _ -> arrs
  Screma _ arrs _ -> arrs
  Hist _ arrs _ _ -> arrs
-- | Replace the inputs of a SOAC.
--
-- For 'Stream' and 'Scatter' the width is recomputed with 'newWidth'
-- from the new inputs.  'Screma' and 'Hist' keep their existing width.
setInputs :: [Input] -> SOAC rep -> SOAC rep
setInputs arrs soac = case soac of
  Stream w _ nes lam -> Stream (newWidth arrs w) arrs nes lam
  Scatter w _ spec lam -> Scatter (newWidth arrs w) arrs spec lam
  Screma w _ form -> Screma w arrs form
  Hist w _ ops lam -> Hist w arrs ops lam
-- | The width implied by a list of inputs: the outermost size of the
-- first input's transformed type.  Falls back to the given width when
-- there are no inputs.
newWidth :: [Input] -> SubExp -> SubExp
newWidth inps fallback = case inps of
  [] -> fallback
  firstInp : _ -> arraySize 0 (inputType firstInp)
-- | The lambda of a SOAC.  For a 'Screma' this is the lambda stored as
-- the first component of its 'ScremaForm'.
lambda :: SOAC rep -> Lambda rep
lambda soac = case soac of
  Stream _ _ _ lam -> lam
  Scatter _ _ _ lam -> lam
  Screma _ _ (ScremaForm lam _ _) -> lam
  Hist _ _ _ lam -> lam
-- | Replace the lambda of a SOAC (the one returned by 'lambda'),
-- leaving everything else untouched.  For a 'Screma', the scans and
-- reductions of its 'ScremaForm' are preserved.
setLambda :: Lambda rep -> SOAC rep -> SOAC rep
setLambda lam soac = case soac of
  Stream w arrs nes _ -> Stream w arrs nes lam
  Scatter w arrs spec _ -> Scatter w arrs spec lam
  Screma w arrs (ScremaForm _ scans reds) ->
    Screma w arrs (ScremaForm lam scans reds)
  Hist w arrs ops _ -> Hist w arrs ops lam
-- | The result types of a SOAC.
typeOf :: SOAC rep -> [Type]
-- The first @length nes@ lambda results are accumulators and are
-- returned as-is.  The remaining results are re-wrapped with their
-- outer dimension replaced by the full width.
typeOf (Stream w _ nes lam) = accTypes ++ map fullWidth chunkTypes
  where
    (accTypes, chunkTypes) = splitAt (length nes) (lambdaReturnType lam)
    fullWidth t = arrayOf (stripArray 1 t) (Shape [w]) NoUniqueness
-- Skip the index results (count * rank per destination); each
-- remaining value type is wrapped in its destination shape.
typeOf (Scatter _ _ dests lam) =
  zipWith arrayOfShape valueTypes destShapes
  where
    (destShapes, counts, _) = unzip3 dests
    numIndexes = sum (zipWith (\c s -> c * length s) counts destShapes)
    valueTypes = drop numIndexes (lambdaReturnType lam)
typeOf (Screma w _ form) = scremaType w form
-- Every histogram operator contributes its operator's return types,
-- each wrapped in that histogram's shape.
typeOf (Hist _ _ ops _) = concatMap histTypes ops
  where
    histTypes op = map (`arrayOfShape` histShape op) (lambdaReturnType (histOp op))
-- | The width of a SOAC: the leading 'SubExp' argument of its
-- constructor.
width :: SOAC rep -> SubExp
width soac = case soac of
  Stream w _ _ _ -> w
  Scatter w _ _ _ -> w
  Screma w _ _ -> w
  Hist w _ _ _ -> w
-- | Convert a SOAC into an 'Op' expression, going through 'toSOAC' to
-- turn every 'Input' into a plain array name in the builder monad.
toExp ::
  (MonadBuilder m, Op (Rep m) ~ Futhark.SOAC (Rep m)) =>
  SOAC (Rep m) ->
  m (Exp (Rep m))
toExp soac = do
  soac' <- toSOAC soac
  pure (Op soac')
-- | Convert to the ordinary SOAC representation.  The 'Input's are
-- turned into plain array names with 'inputsToSubExps' (in the
-- builder monad), and every other component is copied over unchanged.
toSOAC :: (MonadBuilder m) => SOAC (Rep m) -> m (Futhark.SOAC (Rep m))
toSOAC soac = rebuild <$> inputsToSubExps (inputs soac)
  where
    rebuild arrs = case soac of
      Stream w _ nes lam -> Futhark.Stream w arrs nes lam
      Scatter w _ dests lam -> Futhark.Scatter w arrs dests lam
      Screma w _ form -> Futhark.Screma w arrs form
      Hist w _ ops lam -> Futhark.Hist w arrs ops lam
-- | Why 'fromExp' could not turn an expression into a 'SOAC'.
data NotSOAC
  = -- | The expression is not one of the supported SOACs.
    NotSOAC
  deriving (Show)
-- | Try to view an expression as a 'SOAC'.  'Stream', 'Scatter',
-- 'Screma' and 'Hist' operations are converted by wrapping each array
-- argument with 'varInput'.  Any other expression yields
-- @Left NotSOAC@.
fromExp ::
  (Op rep ~ Futhark.SOAC rep, HasScope rep m) =>
  Exp rep ->
  m (Either NotSOAC (SOAC rep))
fromExp e = case e of
  Op (Futhark.Stream w arrs nes lam) ->
    Right . (\inps -> Stream w inps nes lam) <$> traverse varInput arrs
  Op (Futhark.Scatter w arrs spec lam) ->
    Right . (\inps -> Scatter w inps spec lam) <$> traverse varInput arrs
  Op (Futhark.Screma w arrs form) ->
    Right . (\inps -> Screma w inps form) <$> traverse varInput arrs
  Op (Futhark.Hist w arrs ops lam) ->
    Right . (\inps -> Hist w inps ops lam) <$> traverse varInput arrs
  _ -> pure (Left NotSOAC)
soacToStream ::
( HasScope rep m,
MonadFreshNames m,
Buildable rep,
BuilderOps rep,
Op rep ~ Futhark.SOAC rep
) =>
SOAC rep ->
m (SOAC rep, [Ident])
soacToStream :: forall rep (m :: * -> *).
(HasScope rep m, MonadFreshNames m, Buildable rep, BuilderOps rep,
Op rep ~ SOAC rep) =>
SOAC rep -> m (SOAC rep, [Ident])
soacToStream SOAC rep
soac = do
Param (TypeBase Shape NoUniqueness)
chunk_param <- String
-> TypeBase Shape NoUniqueness
-> m (Param (TypeBase Shape NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"chunk" (TypeBase Shape NoUniqueness
-> m (Param (TypeBase Shape NoUniqueness)))
-> TypeBase Shape NoUniqueness
-> m (Param (TypeBase Shape NoUniqueness))
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase Shape NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
let chvar :: SubExp
chvar = VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (TypeBase Shape NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape NoUniqueness)
chunk_param
(Lambda rep
lam, [Input]
inps) = (SOAC rep -> Lambda rep
forall rep. SOAC rep -> Lambda rep
lambda SOAC rep
soac, SOAC rep -> [Input]
forall rep. SOAC rep -> [Input]
inputs SOAC rep
soac)
w :: SubExp
w = SOAC rep -> SubExp
forall rep. SOAC rep -> SubExp
width SOAC rep
soac
Lambda rep
lam' <- Lambda rep -> m (Lambda rep)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda rep
lam
let arrrtps :: [TypeBase Shape NoUniqueness]
arrrtps = SubExp -> Lambda rep -> [TypeBase Shape NoUniqueness]
forall rep. SubExp -> Lambda rep -> [TypeBase Shape NoUniqueness]
mapType SubExp
w Lambda rep
lam
loutps :: [TypeBase Shape NoUniqueness]
loutps = [TypeBase Shape NoUniqueness
-> SubExp -> TypeBase Shape NoUniqueness
forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
arrayOfRow TypeBase Shape NoUniqueness
t SubExp
chvar | TypeBase Shape NoUniqueness
t <- (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
forall u. TypeBase Shape u -> TypeBase Shape u
rowType [TypeBase Shape NoUniqueness]
arrrtps]
lintps :: [TypeBase Shape NoUniqueness]
lintps = [TypeBase Shape NoUniqueness
-> SubExp -> TypeBase Shape NoUniqueness
forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
arrayOfRow TypeBase Shape NoUniqueness
t SubExp
chvar | TypeBase Shape NoUniqueness
t <- (Input -> TypeBase Shape NoUniqueness)
-> [Input] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map Input -> TypeBase Shape NoUniqueness
inputRowType [Input]
inps]
[Param (TypeBase Shape NoUniqueness)]
strm_inpids <- (TypeBase Shape NoUniqueness
-> m (Param (TypeBase Shape NoUniqueness)))
-> [TypeBase Shape NoUniqueness]
-> m [Param (TypeBase Shape NoUniqueness)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String
-> TypeBase Shape NoUniqueness
-> m (Param (TypeBase Shape NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"inp") [TypeBase Shape NoUniqueness]
lintps
case SOAC rep
soac of
Screma SubExp
_ [Input]
_ ScremaForm rep
form
| Just Lambda rep
_ <- ScremaForm rep -> Maybe (Lambda rep)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
Futhark.isMapSOAC ScremaForm rep
form -> do
[Ident]
strm_resids <- (TypeBase Shape NoUniqueness -> m Ident)
-> [TypeBase Shape NoUniqueness] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String -> TypeBase Shape NoUniqueness -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> TypeBase Shape NoUniqueness -> m Ident
newIdent String
"res") [TypeBase Shape NoUniqueness]
loutps
let insoac :: SOAC rep
insoac =
SubExp -> [VName] -> ScremaForm rep -> SOAC rep
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Futhark.Screma SubExp
chvar ((Param (TypeBase Shape NoUniqueness) -> VName)
-> [Param (TypeBase Shape NoUniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName [Param (TypeBase Shape NoUniqueness)]
strm_inpids) (ScremaForm rep -> SOAC rep) -> ScremaForm rep -> SOAC rep
forall a b. (a -> b) -> a -> b
$
Lambda rep -> ScremaForm rep
forall rep. Lambda rep -> ScremaForm rep
Futhark.mapSOAC Lambda rep
lam'
insstm :: Stm rep
insstm = [Ident] -> Exp rep -> Stm rep
forall rep. Buildable rep => [Ident] -> Exp rep -> Stm rep
mkLet [Ident]
strm_resids (Exp rep -> Stm rep) -> Exp rep -> Stm rep
forall a b. (a -> b) -> a -> b
$ Op rep -> Exp rep
forall rep. Op rep -> Exp rep
Op Op rep
SOAC rep
insoac
strmbdy :: Body rep
strmbdy = Stms rep -> Result -> Body rep
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (Stm rep -> Stms rep
forall rep. Stm rep -> Stms rep
oneStm Stm rep
insstm) (Result -> Body rep) -> Result -> Body rep
forall a b. (a -> b) -> a -> b
$ (Ident -> SubExpRes) -> [Ident] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> SubExpRes
subExpRes (SubExp -> SubExpRes) -> (Ident -> SubExp) -> Ident -> SubExpRes
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var (VName -> SubExp) -> (Ident -> VName) -> Ident -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Ident -> VName
identName) [Ident]
strm_resids
strmpar :: [Param (TypeBase Shape NoUniqueness)]
strmpar = Param (TypeBase Shape NoUniqueness)
chunk_param Param (TypeBase Shape NoUniqueness)
-> [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
forall a. a -> [a] -> [a]
: [Param (TypeBase Shape NoUniqueness)]
strm_inpids
strmlam :: Lambda rep
strmlam = [LParam rep]
-> [TypeBase Shape NoUniqueness] -> Body rep -> Lambda rep
forall rep.
[LParam rep]
-> [TypeBase Shape NoUniqueness] -> Body rep -> Lambda rep
Lambda [Param (TypeBase Shape NoUniqueness)]
[LParam rep]
strmpar [TypeBase Shape NoUniqueness]
loutps Body rep
strmbdy
(SOAC rep, [Ident]) -> m (SOAC rep, [Ident])
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp -> [Input] -> [SubExp] -> Lambda rep -> SOAC rep
forall rep. SubExp -> [Input] -> [SubExp] -> Lambda rep -> SOAC rep
Stream SubExp
w [Input]
inps [] Lambda rep
strmlam, [])
| Just ([Scan rep]
scans, Lambda rep
_) <- ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
forall rep. ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
Futhark.isScanomapSOAC ScremaForm rep
form,
Futhark.Scan Lambda rep
scan_lam [SubExp]
nes <- [Scan rep] -> Scan rep
forall rep. Buildable rep => [Scan rep] -> Scan rep
Futhark.singleScan [Scan rep]
scans -> do
let scan_arr_ts :: [TypeBase Shape NoUniqueness]
scan_arr_ts = (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map (TypeBase Shape NoUniqueness
-> SubExp -> TypeBase Shape NoUniqueness
forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
`arrayOfRow` SubExp
chvar) ([TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness])
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> a -> b
$ Lambda rep -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda rep
scan_lam
accrtps :: [TypeBase Shape NoUniqueness]
accrtps = Lambda rep -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda rep
scan_lam
[Param (TypeBase Shape NoUniqueness)]
inpacc_ids <- (TypeBase Shape NoUniqueness
-> m (Param (TypeBase Shape NoUniqueness)))
-> [TypeBase Shape NoUniqueness]
-> m [Param (TypeBase Shape NoUniqueness)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String
-> TypeBase Shape NoUniqueness
-> m (Param (TypeBase Shape NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"inpacc") [TypeBase Shape NoUniqueness]
accrtps
Lambda rep
maplam <- [SubExp] -> Lambda rep -> m (Lambda rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, Buildable rep) =>
[SubExp] -> Lambda rep -> m (Lambda rep)
mkMapPlusAccLam ((Param (TypeBase Shape NoUniqueness) -> SubExp)
-> [Param (TypeBase Shape NoUniqueness)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (Param (TypeBase Shape NoUniqueness) -> VName)
-> Param (TypeBase Shape NoUniqueness)
-> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param (TypeBase Shape NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName) [Param (TypeBase Shape NoUniqueness)]
inpacc_ids) Lambda rep
scan_lam
let strmpar :: [Param (TypeBase Shape NoUniqueness)]
strmpar = Param (TypeBase Shape NoUniqueness)
chunk_param Param (TypeBase Shape NoUniqueness)
-> [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
forall a. a -> [a] -> [a]
: [Param (TypeBase Shape NoUniqueness)]
inpacc_ids [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ [Param (TypeBase Shape NoUniqueness)]
strm_inpids
Lambda rep
strmlam <- ((Lambda rep, Stms rep) -> Lambda rep)
-> m (Lambda rep, Stms rep) -> m (Lambda rep)
forall a b. (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Lambda rep, Stms rep) -> Lambda rep
forall a b. (a, b) -> a
fst (m (Lambda rep, Stms rep) -> m (Lambda rep))
-> (BuilderT rep (State VNameSource) Result
-> m (Lambda rep, Stms rep))
-> BuilderT rep (State VNameSource) Result
-> m (Lambda rep)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Builder rep (Lambda rep) -> m (Lambda rep, Stms rep)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder (Builder rep (Lambda rep) -> m (Lambda rep, Stms rep))
-> (BuilderT rep (State VNameSource) Result
-> Builder rep (Lambda rep))
-> BuilderT rep (State VNameSource) Result
-> m (Lambda rep, Stms rep)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [LParam (Rep (BuilderT rep (State VNameSource)))]
-> BuilderT rep (State VNameSource) Result
-> BuilderT
rep
(State VNameSource)
(Lambda (Rep (BuilderT rep (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda [Param (TypeBase Shape NoUniqueness)]
[LParam (Rep (BuilderT rep (State VNameSource)))]
strmpar (BuilderT rep (State VNameSource) Result -> m (Lambda rep))
-> BuilderT rep (State VNameSource) Result -> m (Lambda rep)
forall a b. (a -> b) -> a -> b
$ do
([VName]
scan0_ids, [VName]
map_resids) <-
([VName] -> ([VName], [VName]))
-> BuilderT rep (State VNameSource) [VName]
-> BuilderT rep (State VNameSource) ([VName], [VName])
forall a b.
(a -> b)
-> BuilderT rep (State VNameSource) a
-> BuilderT rep (State VNameSource) b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([TypeBase Shape NoUniqueness] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TypeBase Shape NoUniqueness]
scan_arr_ts)) (BuilderT rep (State VNameSource) [VName]
-> BuilderT rep (State VNameSource) ([VName], [VName]))
-> (SOAC rep -> BuilderT rep (State VNameSource) [VName])
-> SOAC rep
-> BuilderT rep (State VNameSource) ([VName], [VName])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String
-> Exp (Rep (BuilderT rep (State VNameSource)))
-> BuilderT rep (State VNameSource) [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"scan" (Exp rep -> BuilderT rep (State VNameSource) [VName])
-> (SOAC rep -> Exp rep)
-> SOAC rep
-> BuilderT rep (State VNameSource) [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op rep -> Exp rep
SOAC rep -> Exp rep
forall rep. Op rep -> Exp rep
Op (SOAC rep -> BuilderT rep (State VNameSource) ([VName], [VName]))
-> SOAC rep -> BuilderT rep (State VNameSource) ([VName], [VName])
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> ScremaForm rep -> SOAC rep
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Futhark.Screma SubExp
chvar ((Param (TypeBase Shape NoUniqueness) -> VName)
-> [Param (TypeBase Shape NoUniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName [Param (TypeBase Shape NoUniqueness)]
strm_inpids) (ScremaForm rep -> SOAC rep) -> ScremaForm rep -> SOAC rep
forall a b. (a -> b) -> a -> b
$
[Scan rep] -> Lambda rep -> ScremaForm rep
forall rep. [Scan rep] -> Lambda rep -> ScremaForm rep
Futhark.scanomapSOAC [Lambda rep -> [SubExp] -> Scan rep
forall rep. Lambda rep -> [SubExp] -> Scan rep
Futhark.Scan Lambda rep
scan_lam [SubExp]
nes] Lambda rep
lam'
SubExp
outszm1id <-
String
-> Exp (Rep (BuilderT rep (State VNameSource)))
-> BuilderT rep (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"outszm1" (Exp rep -> BuilderT rep (State VNameSource) SubExp)
-> (BasicOp -> Exp rep)
-> BasicOp
-> BuilderT rep (State VNameSource) SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> BuilderT rep (State VNameSource) SubExp)
-> BasicOp -> BuilderT rep (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$
BinOp -> SubExp -> SubExp -> BasicOp
BinOp
(IntType -> Overflow -> BinOp
Sub IntType
Int64 Overflow
OverflowUndef)
(VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (TypeBase Shape NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape NoUniqueness)
chunk_param)
(Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
1 :: Int64))
VName
empty_arr <-
String
-> Exp (Rep (BuilderT rep (State VNameSource)))
-> BuilderT rep (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"empty_arr" (Exp rep -> BuilderT rep (State VNameSource) VName)
-> (BasicOp -> Exp rep)
-> BasicOp
-> BuilderT rep (State VNameSource) VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> BuilderT rep (State VNameSource) VName)
-> BasicOp -> BuilderT rep (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$
CmpOp -> SubExp -> SubExp -> BasicOp
CmpOp
(IntType -> CmpOp
CmpSlt IntType
Int64)
SubExp
outszm1id
(Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64))
let indexLast :: VName
-> BuilderT
rep
(State VNameSource)
(Exp (Rep (BuilderT rep (State VNameSource))))
indexLast VName
arr = VName
-> [BuilderT
rep
(State VNameSource)
(Exp (Rep (BuilderT rep (State VNameSource))))]
-> BuilderT
rep
(State VNameSource)
(Exp (Rep (BuilderT rep (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
VName -> [m (Exp (Rep m))] -> m (Exp (Rep m))
eIndex VName
arr [SubExp
-> BuilderT
rep
(State VNameSource)
(Exp (Rep (BuilderT rep (State VNameSource))))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
outszm1id]
[VName]
lastel_ids <-
String
-> Exp (Rep (BuilderT rep (State VNameSource)))
-> BuilderT rep (State VNameSource) [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"lastel"
(Exp rep -> BuilderT rep (State VNameSource) [VName])
-> BuilderT rep (State VNameSource) (Exp rep)
-> BuilderT rep (State VNameSource) [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT
rep
(State VNameSource)
(Exp (Rep (BuilderT rep (State VNameSource))))
-> BuilderT
rep
(State VNameSource)
(Body (Rep (BuilderT rep (State VNameSource))))
-> BuilderT
rep
(State VNameSource)
(Body (Rep (BuilderT rep (State VNameSource))))
-> BuilderT
rep
(State VNameSource)
(Exp (Rep (BuilderT rep (State VNameSource))))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(SubExp
-> BuilderT
rep
(State VNameSource)
(Exp (Rep (BuilderT rep (State VNameSource))))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp
-> BuilderT
rep
(State VNameSource)
(Exp (Rep (BuilderT rep (State VNameSource)))))
-> SubExp
-> BuilderT
rep
(State VNameSource)
(Exp (Rep (BuilderT rep (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
empty_arr)
([SubExp]
-> BuilderT
rep
(State VNameSource)
(Body (Rep (BuilderT rep (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [SubExp]
nes)
([BuilderT
rep
(State VNameSource)
(Exp (Rep (BuilderT rep (State VNameSource))))]
-> BuilderT
rep
(State VNameSource)
(Body (Rep (BuilderT rep (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([BuilderT
rep
(State VNameSource)
(Exp (Rep (BuilderT rep (State VNameSource))))]
-> BuilderT
rep
(State VNameSource)
(Body (Rep (BuilderT rep (State VNameSource)))))
-> [BuilderT
rep
(State VNameSource)
(Exp (Rep (BuilderT rep (State VNameSource))))]
-> BuilderT
rep
(State VNameSource)
(Body (Rep (BuilderT rep (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> BuilderT rep (State VNameSource) (Exp rep))
-> [VName] -> [BuilderT rep (State VNameSource) (Exp rep)]
forall a b. (a -> b) -> [a] -> [b]
map VName -> BuilderT rep (State VNameSource) (Exp rep)
VName
-> BuilderT
rep
(State VNameSource)
(Exp (Rep (BuilderT rep (State VNameSource))))
indexLast [VName]
scan0_ids)
Body rep
addlelbdy <-
Lambda rep
-> [SubExp] -> BuilderT rep (State VNameSource) (Body rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, Buildable rep) =>
Lambda rep -> [SubExp] -> m (Body rep)
mkPlusBnds Lambda rep
scan_lam ([SubExp] -> BuilderT rep (State VNameSource) (Body rep))
-> [SubExp] -> BuilderT rep (State VNameSource) (Body rep)
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> [SubExp]) -> [VName] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ (Param (TypeBase Shape NoUniqueness) -> VName)
-> [Param (TypeBase Shape NoUniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName [Param (TypeBase Shape NoUniqueness)]
inpacc_ids [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
lastel_ids
let (Stms rep
addlelstm, Result
addlelres) = (Body rep -> Stms rep
forall rep. Body rep -> Stms rep
bodyStms Body rep
addlelbdy, Body rep -> Result
forall rep. Body rep -> Result
bodyResult Body rep
addlelbdy)
[VName]
strm_resids <-
String
-> Exp (Rep (BuilderT rep (State VNameSource)))
-> BuilderT rep (State VNameSource) [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"strm_res" (Exp rep -> BuilderT rep (State VNameSource) [VName])
-> (SOAC rep -> Exp rep)
-> SOAC rep
-> BuilderT rep (State VNameSource) [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op rep -> Exp rep
SOAC rep -> Exp rep
forall rep. Op rep -> Exp rep
Op (SOAC rep -> BuilderT rep (State VNameSource) [VName])
-> SOAC rep -> BuilderT rep (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> ScremaForm rep -> SOAC rep
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Futhark.Screma SubExp
chvar [VName]
scan0_ids (Lambda rep -> ScremaForm rep
forall rep. Lambda rep -> ScremaForm rep
Futhark.mapSOAC Lambda rep
maplam)
Stms (Rep (BuilderT rep (State VNameSource)))
-> BuilderT rep (State VNameSource) ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms Stms rep
Stms (Rep (BuilderT rep (State VNameSource)))
addlelstm
Result -> BuilderT rep (State VNameSource) Result
forall a. a -> BuilderT rep (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> BuilderT rep (State VNameSource) Result)
-> Result -> BuilderT rep (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ Result
addlelres Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ (VName -> SubExpRes) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> SubExpRes
subExpRes (SubExp -> SubExpRes) -> (VName -> SubExp) -> VName -> SubExpRes
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName]
strm_resids [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
map_resids)
(SOAC rep, [Ident]) -> m (SOAC rep, [Ident])
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure
( SubExp -> [Input] -> [SubExp] -> Lambda rep -> SOAC rep
forall rep. SubExp -> [Input] -> [SubExp] -> Lambda rep -> SOAC rep
Stream SubExp
w [Input]
inps [SubExp]
nes Lambda rep
strmlam,
(Param (TypeBase Shape NoUniqueness) -> Ident)
-> [Param (TypeBase Shape NoUniqueness)] -> [Ident]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape NoUniqueness) -> Ident
forall dec. Typed dec => Param dec -> Ident
paramIdent [Param (TypeBase Shape NoUniqueness)]
inpacc_ids
)
| Just ([Reduce rep]
reds, Lambda rep
_) <- ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
forall rep. ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
Futhark.isRedomapSOAC ScremaForm rep
form,
Futhark.Reduce Commutativity
comm Lambda rep
lamin [SubExp]
nes <- [Reduce rep] -> Reduce rep
forall rep. Buildable rep => [Reduce rep] -> Reduce rep
Futhark.singleReduce [Reduce rep]
reds -> do
let accrtps :: [TypeBase Shape NoUniqueness]
accrtps = Int
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a. Int -> [a] -> [a]
take ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness])
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> a -> b
$ Lambda rep -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda rep
lam
loutps' :: [TypeBase Shape NoUniqueness]
loutps' = Int
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a. Int -> [a] -> [a]
drop ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) [TypeBase Shape NoUniqueness]
loutps
foldlam :: Lambda rep
foldlam = Lambda rep
lam'
[Ident]
strm_resids <- (TypeBase Shape NoUniqueness -> m Ident)
-> [TypeBase Shape NoUniqueness] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String -> TypeBase Shape NoUniqueness -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> TypeBase Shape NoUniqueness -> m Ident
newIdent String
"res") [TypeBase Shape NoUniqueness]
loutps'
[Param (TypeBase Shape NoUniqueness)]
inpacc_ids <- (TypeBase Shape NoUniqueness
-> m (Param (TypeBase Shape NoUniqueness)))
-> [TypeBase Shape NoUniqueness]
-> m [Param (TypeBase Shape NoUniqueness)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String
-> TypeBase Shape NoUniqueness
-> m (Param (TypeBase Shape NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"inpacc") [TypeBase Shape NoUniqueness]
accrtps
[Ident]
acc0_ids <- (TypeBase Shape NoUniqueness -> m Ident)
-> [TypeBase Shape NoUniqueness] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String -> TypeBase Shape NoUniqueness -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> TypeBase Shape NoUniqueness -> m Ident
newIdent String
"acc0") [TypeBase Shape NoUniqueness]
accrtps
let insoac :: SOAC rep
insoac =
SubExp -> [VName] -> ScremaForm rep -> SOAC rep
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Futhark.Screma
SubExp
chvar
((Param (TypeBase Shape NoUniqueness) -> VName)
-> [Param (TypeBase Shape NoUniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName [Param (TypeBase Shape NoUniqueness)]
strm_inpids)
(ScremaForm rep -> SOAC rep) -> ScremaForm rep -> SOAC rep
forall a b. (a -> b) -> a -> b
$ [Reduce rep] -> Lambda rep -> ScremaForm rep
forall rep. [Reduce rep] -> Lambda rep -> ScremaForm rep
Futhark.redomapSOAC [Commutativity -> Lambda rep -> [SubExp] -> Reduce rep
forall rep. Commutativity -> Lambda rep -> [SubExp] -> Reduce rep
Futhark.Reduce Commutativity
comm Lambda rep
lamin [SubExp]
nes] Lambda rep
foldlam
insstm :: Stm rep
insstm = [Ident] -> Exp rep -> Stm rep
forall rep. Buildable rep => [Ident] -> Exp rep -> Stm rep
mkLet ([Ident]
acc0_ids [Ident] -> [Ident] -> [Ident]
forall a. [a] -> [a] -> [a]
++ [Ident]
strm_resids) (Exp rep -> Stm rep) -> Exp rep -> Stm rep
forall a b. (a -> b) -> a -> b
$ Op rep -> Exp rep
forall rep. Op rep -> Exp rep
Op Op rep
SOAC rep
insoac
Body rep
addaccbdy <-
Lambda rep -> [SubExp] -> m (Body rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, Buildable rep) =>
Lambda rep -> [SubExp] -> m (Body rep)
mkPlusBnds Lambda rep
lamin ([SubExp] -> m (Body rep)) -> [SubExp] -> m (Body rep)
forall a b. (a -> b) -> a -> b
$
(VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> [SubExp]) -> [VName] -> [SubExp]
forall a b. (a -> b) -> a -> b
$
(Param (TypeBase Shape NoUniqueness) -> VName)
-> [Param (TypeBase Shape NoUniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName [Param (TypeBase Shape NoUniqueness)]
inpacc_ids [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ (Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
acc0_ids
let (Stms rep
addaccstm, Result
addaccres) = (Body rep -> Stms rep
forall rep. Body rep -> Stms rep
bodyStms Body rep
addaccbdy, Body rep -> Result
forall rep. Body rep -> Result
bodyResult Body rep
addaccbdy)
strmbdy :: Body rep
strmbdy =
Stms rep -> Result -> Body rep
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (Stm rep -> Stms rep
forall rep. Stm rep -> Stms rep
oneStm Stm rep
insstm Stms rep -> Stms rep -> Stms rep
forall a. Semigroup a => a -> a -> a
<> Stms rep
addaccstm) (Result -> Body rep) -> Result -> Body rep
forall a b. (a -> b) -> a -> b
$
Result
addaccres Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ (Ident -> SubExpRes) -> [Ident] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> SubExpRes
subExpRes (SubExp -> SubExpRes) -> (Ident -> SubExp) -> Ident -> SubExpRes
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var (VName -> SubExp) -> (Ident -> VName) -> Ident -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Ident -> VName
identName) [Ident]
strm_resids
strmpar :: [Param (TypeBase Shape NoUniqueness)]
strmpar = Param (TypeBase Shape NoUniqueness)
chunk_param Param (TypeBase Shape NoUniqueness)
-> [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
forall a. a -> [a] -> [a]
: [Param (TypeBase Shape NoUniqueness)]
inpacc_ids [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ [Param (TypeBase Shape NoUniqueness)]
strm_inpids
strmlam :: Lambda rep
strmlam = [LParam rep]
-> [TypeBase Shape NoUniqueness] -> Body rep -> Lambda rep
forall rep.
[LParam rep]
-> [TypeBase Shape NoUniqueness] -> Body rep -> Lambda rep
Lambda [Param (TypeBase Shape NoUniqueness)]
[LParam rep]
strmpar ([TypeBase Shape NoUniqueness]
accrtps [TypeBase Shape NoUniqueness]
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a. [a] -> [a] -> [a]
++ [TypeBase Shape NoUniqueness]
loutps') Body rep
strmbdy
(SOAC rep, [Ident]) -> m (SOAC rep, [Ident])
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp -> [Input] -> [SubExp] -> Lambda rep -> SOAC rep
forall rep. SubExp -> [Input] -> [SubExp] -> Lambda rep -> SOAC rep
Stream SubExp
w [Input]
inps [SubExp]
nes Lambda rep
strmlam, [])
SOAC rep
_ -> (SOAC rep, [Ident]) -> m (SOAC rep, [Ident])
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SOAC rep
soac, [])
where
-- | Partially apply the binary operator @plus@ to the accumulator values
-- @accs@.
--
-- The first @length accs@ parameters of @plus@ are bound, in order, to the
-- elements of @accs@ by let-statements prepended to the lambda body. The
-- resulting lambda takes only the remaining parameters and keeps the
-- original return type. It is passed through 'renameLambda' so that every
-- binding it contains gets a fresh name.
--
-- Because the pairing uses 'zipWith', surplus elements of @accs@ (beyond the
-- number of parameters) are ignored.
mkMapPlusAccLam ::
  (MonadFreshNames m, Buildable rep) =>
  [SubExp] ->
  Lambda rep ->
  m (Lambda rep)
mkMapPlusAccLam accs plus =
  renameLambda $ Lambda remaining_params (lambdaReturnType plus) specialised_body
  where
    -- Leading parameters are fixed to the accumulators; the rest stay free.
    (acc_params, remaining_params) = splitAt (length accs) (lambdaParams plus)
    -- One @let p = acc@ statement per fixed parameter.
    bindAcc p acc = mkLet [paramIdent p] $ BasicOp $ SubExp acc
    acc_binds = stmsFromList $ zipWith bindAcc acc_params accs
    old_body = lambdaBody plus
    -- Same decoration and result as before; only the statements grow.
    specialised_body = old_body {bodyStms = acc_binds <> bodyStms old_body}
-- | Inline the operator @plus@ applied to the arguments @accels@, producing
-- a body.
--
-- A fresh copy of @plus@ is obtained with 'renameLambda'. Each of its
-- parameters is then bound, in order, to the matching element of @accels@
-- via a let-statement placed before the lambda's own statements. The result
-- of the returned body is the lambda's result.
--
-- Pairing uses 'zipWith', so extra parameters or extra arguments are left
-- unpaired.
mkPlusBnds ::
  (MonadFreshNames m, Buildable rep) =>
  Lambda rep ->
  [SubExp] ->
  m (Body rep)
mkPlusBnds plus accels = applyTo <$> renameLambda plus
  where
    -- Turn the freshly renamed lambda into a body with its parameters bound.
    applyTo fresh =
      let fresh_body = lambdaBody fresh
          arg_binds = stmsFromList $ zipWith bindArg (lambdaParams fresh) accels
       in fresh_body {bodyStms = arg_binds <> bodyStms fresh_body}
    -- @let p = se@ for a single parameter/argument pair.
    bindArg p se = mkLet [paramIdent p] $ BasicOp $ SubExp se
-- | Pretty-print the array transformation @t@ applied to the document @e@,
-- where @e@ stands for the array being transformed.
--
-- Every transform except 'Index' is rendered as a keyword, then its
-- certificates, then 'PP.apply' of its argument(s) followed by @e@.
-- Reshape-like transforms use a @reshape*@ keyword for 'ReshapeArbitrary'
-- and a @coerce*@ keyword for 'ReshapeCoerce'. 'Index' is rendered as @e@
-- followed by its certificates and then the slice.
ppArrayTransform :: PP.Doc a -> ArrayTransform -> PP.Doc a
ppArrayTransform e t =
  case t of
    Rearrange cs perm ->
      "rearrange" <> pretty cs <> PP.apply [PP.apply (map pretty perm), e]
    Reshape cs k shape ->
      shaped (byKind k "reshape" "coerce") cs shape
    ReshapeOuter cs k shape ->
      shaped (byKind k "reshape_outer" "coerce_outer") cs shape
    ReshapeInner cs k shape ->
      shaped (byKind k "reshape_inner" "coerce_inner") cs shape
    Replicate cs shape ->
      shaped "replicate" cs shape
    Index cs slice ->
      e <> pretty cs <> pretty slice
  where
    -- Common layout for every transform carrying a shape: keyword, then
    -- certificates, then 'PP.apply' of the shape and the input document.
    shaped keyword cs shape = keyword <> pretty cs <> PP.apply [pretty shape, e]
    -- Pick the keyword matching the reshape kind.
    byKind ReshapeArbitrary arbitrary _ = arbitrary
    byKind ReshapeCoerce _ coerce = coerce
instance PP.Pretty Input where
  -- Start from the array's name and wrap it in each transform in sequence
  -- order. The first transform ends up innermost and the last outermost.
  -- The input's type is not shown.
  pretty (Input (ArrayTransforms transforms) arrName _) =
    foldl wrapWith (pretty arrName) transforms
    where
      wrapWith doc transform = ppArrayTransform doc transform
instance PP.Pretty ArrayTransform where
  -- A transform on its own has no array to act on, so it is rendered as
  -- applied to a fixed placeholder document.
  pretty transform = ppArrayTransform placeholder transform
    where
      placeholder = "INPUT"
instance (PrettyRep rep) => PP.Pretty (SOAC rep) where
  -- Each SOAC is delegated to the corresponding @Futhark.pp*@ printer.
  -- Those printers are polymorphic in the element type of the input list
  -- (requiring only 'PP.Pretty'), so our 'Input's are passed through as-is.
  pretty (Stream sz inps nes fun) = Futhark.ppStream sz inps nes fun
  pretty (Scatter sz inps spec fun) = Futhark.ppScatter sz inps spec fun
  pretty (Screma sz inps scremaForm) = Futhark.ppScrema sz inps scremaForm
  pretty (Hist sz inps histOps bucketFun) = Futhark.ppHist sz inps histOps bucketFun