{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Futhark.IR.SOACS.SOAC
( SOAC (..),
ScremaForm (..),
ScatterSpec,
HistOp (..),
Scan (..),
scanResults,
singleScan,
Reduce (..),
redResults,
singleReduce,
scremaType,
soacType,
typeCheckSOAC,
mkIdentityLambda,
isIdentityLambda,
nilFn,
scanomapSOAC,
redomapSOAC,
scanSOAC,
reduceSOAC,
mapSOAC,
isScanomapSOAC,
isRedomapSOAC,
isScanSOAC,
isReduceSOAC,
isMapSOAC,
ppScrema,
ppHist,
ppStream,
ppScatter,
groupScatterResults,
groupScatterResults',
splitScatterResults,
SOACMapper (..),
identitySOACMapper,
mapSOACM,
traverseSOACStms,
)
where
import Control.Category
import Control.Monad
import Control.Monad.Identity
import Control.Monad.State.Strict
import Control.Monad.Writer
import Data.Function ((&))
import Data.List (intersperse)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Aliases (Aliases, CanBeAliased (..))
import Futhark.IR.Prop.Aliases
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Rep
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (chunks, maybeNth, splitAt3)
import Futhark.Util.Pretty (Doc, align, comma, commasep, docText, parens, ppTuple', pretty, (<+>), (</>))
import Futhark.Util.Pretty qualified as PP
import Prelude hiding (id, (.))
-- | The destination description of a 'Scatter': a list of triples, each
-- holding a 'Shape', an 'Int', and a @v@ (a 'VName' inside 'Scatter').
-- NOTE(review): presumably (destination shape, number of values written
-- per iteration, destination array) — confirm against
-- 'groupScatterResults'.
type ScatterSpec v = [(Shape, Int, v)]
-- | A second-order array combinator (SOAC).  Every constructor carries
-- at least one 'Lambda' in the representation @rep@.
data SOAC rep
  = -- | @Stream w arrs accs lam@: a size, input arrays, a list of
    -- 'SubExp's, and a lambda.
    -- NOTE(review): the 'SubExp' list is presumably the initial
    -- accumulator values — confirm.
    Stream SubExp [VName] [SubExp] (Lambda rep)
  | -- | @Scatter w arrs spec lam@: a size, input arrays, a 'ScatterSpec'
    -- (whose 'VName's are presumably the destination arrays — confirm),
    -- and a lambda.
    Scatter SubExp [VName] (ScatterSpec VName) (Lambda rep)
  | -- | @Hist w arrs ops lam@: a size, input arrays, one 'HistOp' per
    -- histogram, and a lambda.
    -- NOTE(review): the lambda presumably produces bucket indices and
    -- values — confirm.
    Hist SubExp [VName] [HistOp rep] (Lambda rep)
  | -- | Two lists of 'SubExp's and a lambda.
    -- NOTE(review): presumably forward-mode differentiation
    -- (Jacobian-vector product) of the lambda — confirm.
    JVP [SubExp] [SubExp] (Lambda rep)
  | -- | Two lists of 'SubExp's and a lambda.
    -- NOTE(review): presumably reverse-mode differentiation
    -- (vector-Jacobian product) of the lambda — confirm.
    VJP [SubExp] [SubExp] (Lambda rep)
  | -- | @Screma w arrs form@: a size, input arrays, and a 'ScremaForm'
    -- combining a map with scans and reductions.  The size is presumably
    -- the same @w@ that 'scremaType' takes — confirm.
    Screma SubExp [VName] (ScremaForm rep)
  deriving (Eq, Ord, Show)
-- | One histogram computed by a 'Hist' SOAC.
data HistOp rep = HistOp
  { -- | NOTE(review): presumably the shape of the histogram — confirm.
    histShape :: Shape,
    -- | NOTE(review): presumably a hint controlling how many
    -- subhistograms are used to reduce contention — confirm.
    histRaceFactor :: SubExp,
    -- | The destination arrays of this histogram.
    histDest :: [VName],
    -- | Neutral elements for 'histOp'.
    histNeutral :: [SubExp],
    -- | The operator used to combine values in the same bucket.
    histOp :: Lambda rep
  }
  deriving (Eq, Ord, Show)
-- | The body of a 'Screma': a map lambda together with any number of
-- scans and reductions.  'scremaType' treats the map lambda's first
-- results as belonging to the scans, then the reductions, and the
-- remaining results as plain map outputs.
data ScremaForm rep = ScremaForm
  { -- | The map lambda applied to each element of the input arrays.
    scremaLambda :: Lambda rep,
    -- | Scans over (a prefix of) the map lambda's results.
    scremaScans :: [Scan rep],
    -- | Reductions over the map lambda's results following the scans.
    scremaReduces :: [Reduce rep]
  }
  deriving (Eq, Ord, Show)
-- | Fuse several binary-operator lambdas into a single lambda.
--
-- A lambda with @n@ return types is treated as taking its first @n@
-- parameters as the left operand and its remaining parameters as the
-- right operand.  The fused lambda takes every left operand (in input
-- order) followed by every right operand, runs the bodies' statements
-- one after another, and returns all results concatenated.  An empty
-- input list yields a lambda with no parameters, results or statements.
--
-- NOTE(review): statements are concatenated without renaming, so the
-- inputs presumably must use pairwise-distinct names — confirm at call
-- sites.
singleBinOp :: (Buildable rep) => [Lambda rep] -> Lambda rep
singleBinOp lams =
  Lambda
    { lambdaParams = concatMap xParams lams ++ concatMap yParams lams,
      lambdaReturnType = concatMap lambdaReturnType lams,
      lambdaBody =
        mkBody
          (mconcat (map (bodyStms . lambdaBody) lams))
          (concatMap (bodyResult . lambdaBody) lams)
    }
  where
    -- Left-operand parameters: one per value the lambda returns.
    xParams lam = take (length (lambdaReturnType lam)) (lambdaParams lam)
    -- Right-operand parameters: everything after the left operand.
    yParams lam = drop (length (lambdaReturnType lam)) (lambdaParams lam)
-- | A scan: an operator lambda and its neutral elements.
data Scan rep = Scan
  { -- | The scan operator.
    scanLambda :: Lambda rep,
    -- | Neutral elements; one per value produced by the scan.
    scanNeutral :: [SubExp]
  }
  deriving (Eq, Ord, Show)
-- | The number of neutral elements of each scan, in order.
scanSizes :: [Scan rep] -> [Int]
scanSizes = map (length . scanNeutral)
-- | The total number of neutral elements over all the given scans.
scanResults :: [Scan rep] -> Int
scanResults = sum . scanSizes
-- | Combine several scans into one.  The operators are fused with
-- 'singleBinOp' and the neutral elements are concatenated in order.
singleScan :: (Buildable rep) => [Scan rep] -> Scan rep
singleScan scans =
  let scan_nes = concatMap scanNeutral scans
      scan_lam = singleBinOp $ map scanLambda scans
   in Scan scan_lam scan_nes
-- | A reduction: its commutativity, operator lambda, and neutral
-- elements.
data Reduce rep = Reduce
  { -- | Whether the operator is declared commutative.
    redComm :: Commutativity,
    -- | The reduction operator.
    redLambda :: Lambda rep,
    -- | Neutral elements; one per value produced by the reduction.
    redNeutral :: [SubExp]
  }
  deriving (Eq, Ord, Show)
-- | The number of neutral elements of each reduction, in order.
redSizes :: [Reduce rep] -> [Int]
redSizes = map (length . redNeutral)
-- | The total number of neutral elements over all the given reductions.
redResults :: [Reduce rep] -> Int
redResults = sum . redSizes
-- | Combine several reductions into one.  The operators are fused with
-- 'singleBinOp', the neutral elements are concatenated in order, and
-- the commutativities are combined with 'mconcat'.
singleReduce :: (Buildable rep) => [Reduce rep] -> Reduce rep
singleReduce reds =
  let red_nes = concatMap redNeutral reds
      red_lam = singleBinOp $ map redLambda reds
   in Reduce (mconcat (map redComm reds)) red_lam red_nes
-- | The result types of a screma whose outer size is @w@.
--
-- The results come in this order:
--
-- * the scan results, each as an array of @w@ rows;
-- * the reduction results, as returned by the reduction operators;
-- * the map results, each as an array of @w@ rows.
--
-- The map results are the map lambda's return types after dropping as
-- many leading types as there are scan and reduction results.
scremaType :: SubExp -> ScremaForm rep -> [Type]
scremaType w (ScremaForm map_lam scans reds) =
  scan_tps ++ red_tps ++ map (`arrayOfRow` w) map_tps
  where
    scan_tps =
      map (`arrayOfRow` w) $
        concatMap (lambdaReturnType . scanLambda) scans
    red_tps = concatMap (lambdaReturnType . redLambda) reds
    map_tps =
      drop (length scan_tps + length red_tps) $ lambdaReturnType map_lam
-- | Construct a lambda that returns its parameters unchanged.  One
-- fresh parameter (based on the name @"x"@) is created per given type,
-- and the return types are exactly the given types.
mkIdentityLambda ::
  (Buildable rep, MonadFreshNames m) =>
  [Type] ->
  m (Lambda rep)
mkIdentityLambda ts = do
  params <- mapM (newParam "x") ts
  pure
    Lambda
      { lambdaParams = params,
        lambdaBody = mkBody mempty $ varsRes $ map paramName params,
        lambdaReturnType = ts
      }
-- | True when the lambda's results are exactly its parameters, as
-- variables and in the same order.  Only the result 'SubExp's are
-- compared: certificates on the results and any statements in the body
-- are not inspected.
isIdentityLambda :: Lambda rep -> Bool
isIdentityLambda lam =
  map resSubExp (bodyResult (lambdaBody lam))
    == map (Var . paramName) (lambdaParams lam)
-- | A lambda with no parameters, no return types, and an empty body.
nilFn :: (Buildable rep) => Lambda rep
nilFn = Lambda mempty mempty (mkBody mempty mempty)
-- | A 'ScremaForm' with the given scans, no reductions, and the given
-- map lambda.
scanomapSOAC :: [Scan rep] -> Lambda rep -> ScremaForm rep
scanomapSOAC scans lam = ScremaForm lam scans []
-- | A 'ScremaForm' with the given reductions, no scans, and the given
-- map lambda.
redomapSOAC :: [Reduce rep] -> Lambda rep -> ScremaForm rep
redomapSOAC reds lam = ScremaForm lam [] reds
-- | A pure scan: a 'ScremaForm' with the given scans and an identity
-- map lambda whose types are the concatenated return types of the scan
-- operators.
scanSOAC ::
  (Buildable rep, MonadFreshNames m) =>
  [Scan rep] ->
  m (ScremaForm rep)
scanSOAC scans = scanomapSOAC scans <$> mkIdentityLambda ts
  where
    ts = concatMap (lambdaReturnType . scanLambda) scans
-- | A pure reduction: a 'ScremaForm' with the given reductions and an
-- identity map lambda whose types are the concatenated return types of
-- the reduction operators.
reduceSOAC ::
  (Buildable rep, MonadFreshNames m) =>
  [Reduce rep] ->
  m (ScremaForm rep)
reduceSOAC reds = redomapSOAC reds <$> mkIdentityLambda ts
  where
    ts = concatMap (lambdaReturnType . redLambda) reds
-- | Construct a 'ScremaForm' that only maps the given function, with no
-- scans and no reductions.
mapSOAC :: Lambda rep -> ScremaForm rep
mapSOAC lam = ScremaForm lam [] []
-- | If the form has no reductions and at least one scan, return its scans
-- together with its map function; otherwise 'Nothing'.
isScanomapSOAC :: ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC (ScremaForm map_lam scans reds) = do
  guard $ null reds
  guard $ not $ null scans
  pure (scans, map_lam)
-- | Like 'isScanomapSOAC', but additionally requires the map function to be
-- an identity lambda (per 'isIdentityLambda'). Only the scans are returned.
isScanSOAC :: ScremaForm rep -> Maybe [Scan rep]
isScanSOAC form = do
  (scans, map_lam) <- isScanomapSOAC form
  guard $ isIdentityLambda map_lam
  pure scans
-- | If the form has no scans and at least one reduction, return its
-- reductions together with its map function; otherwise 'Nothing'.
isRedomapSOAC :: ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC (ScremaForm map_lam scans reds) = do
  guard $ null scans
  guard $ not $ null reds
  pure (reds, map_lam)
-- | Like 'isRedomapSOAC', but additionally requires the map function to be
-- an identity lambda (per 'isIdentityLambda'). Only the reductions are
-- returned.
isReduceSOAC :: ScremaForm rep -> Maybe [Reduce rep]
isReduceSOAC form = do
  (reds, map_lam) <- isRedomapSOAC form
  guard $ isIdentityLambda map_lam
  pure reds
-- | If the form has neither scans nor reductions, return its map function;
-- otherwise 'Nothing'.
isMapSOAC :: ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC (ScremaForm map_lam scans reds) = do
  guard $ null scans
  guard $ null reds
  pure map_lam
-- | Split a list of scatter results into its index part and its value part.
--
-- Each spec entry @(shape, n, _)@ contributes @n * length shape@ index
-- results, i.e. @n@ index tuples with one index per dimension of @shape@.
-- All index results come first; everything after them is returned as the
-- value part.
splitScatterResults :: [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults output_spec results =
  let (shapes, ns, _) = unzip3 output_spec
      num_indices :: Int
      num_indices = sum $ zipWith (*) ns $ map length shapes
   in splitAt num_indices results
-- | Pair each scatter value with its index tuple.
--
-- The index part (see 'splitScatterResults') is cut into chunks. For every
-- spec entry @(shape, n, _)@ there are @n@ chunks, each of size
-- @length shape@. Each chunk is zipped positionally with the value part.
groupScatterResults' :: [(Shape, Int, array)] -> [a] -> [([a], a)]
groupScatterResults' output_spec results =
  let (indices, values) = splitScatterResults output_spec results
      (shapes, ns, _) = unzip3 output_spec
      chunk_sizes :: [Int]
      chunk_sizes =
        concat $ zipWith (\shp n -> replicate n $ length shp) shapes ns
   in zip (chunks chunk_sizes indices) values
-- | Group scatter results per output.
--
-- The @(indices, value)@ pairs from 'groupScatterResults'' are chunked into
-- groups of @n@ pairs, one group per spec entry @(shape, n, array)@. Each
-- group is tagged with that entry's shape and array.
groupScatterResults :: ScatterSpec array -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults output_spec results =
  let (shapes, ns, arrays) = unzip3 output_spec
   in groupScatterResults' output_spec results
        & chunks ns
        & zip3 shapes arrays
-- | A record of monadic actions that 'mapSOACM' applies to the
-- subexpressions, lambdas and variable names directly contained in a 'SOAC'.
-- The lambda action may change the representation from @frep@ to @trep@.
data SOACMapper frep trep m = SOACMapper
  { mapOnSOACSubExp :: SubExp -> m SubExp,
    mapOnSOACLambda :: Lambda frep -> m (Lambda trep),
    mapOnSOACVName :: VName -> m VName
  }
-- | A mapper whose every action simply returns its argument unchanged. It is
-- intended as a base to override individual fields with record update
-- syntax.
identitySOACMapper :: forall rep m. (Monad m) => SOACMapper rep rep m
identitySOACMapper =
  SOACMapper
    { mapOnSOACSubExp = pure,
      mapOnSOACLambda = pure,
      mapOnSOACVName = pure
    }
-- | Apply a 'SOACMapper' to every subexpression, variable name and lambda
-- directly contained in a 'SOAC'. This includes those nested inside
-- 'HistOp's, scatter specs, 'Scan's and 'Reduce's.
--
-- Lambda bodies are not traversed here; that is up to 'mapOnSOACLambda'.
-- Plain data such as the scatter spec's 'Int' and a reduction's
-- 'Commutativity' is passed through unchanged. Effects run in the order the
-- fields appear in each constructor.
mapSOACM ::
  (Monad m) =>
  SOACMapper frep trep m ->
  SOAC frep ->
  m (SOAC trep)
mapSOACM tv (JVP args vec lam) =
  JVP
    <$> mapM (mapOnSOACSubExp tv) args
    <*> mapM (mapOnSOACSubExp tv) vec
    <*> mapOnSOACLambda tv lam
mapSOACM tv (VJP args vec lam) =
  VJP
    <$> mapM (mapOnSOACSubExp tv) args
    <*> mapM (mapOnSOACSubExp tv) vec
    <*> mapOnSOACLambda tv lam
mapSOACM tv (Stream size arrs accs lam) =
  Stream
    <$> mapOnSOACSubExp tv size
    <*> mapM (mapOnSOACVName tv) arrs
    <*> mapM (mapOnSOACSubExp tv) accs
    <*> mapOnSOACLambda tv lam
mapSOACM tv (Scatter w ivs spec lam) =
  Scatter
    <$> mapOnSOACSubExp tv w
    <*> mapM (mapOnSOACVName tv) ivs
    <*> mapM
      ( \(aw, an, a) ->
          (,,)
            <$> mapM (mapOnSOACSubExp tv) aw
            <*> pure an
            <*> mapOnSOACVName tv a
      )
      spec
    <*> mapOnSOACLambda tv lam
mapSOACM tv (Hist w arrs ops bucket_fun) =
  Hist
    <$> mapOnSOACSubExp tv w
    <*> mapM (mapOnSOACVName tv) arrs
    <*> mapM
      ( \(HistOp shape rf op_arrs nes op) ->
          HistOp
            <$> mapM (mapOnSOACSubExp tv) shape
            <*> mapOnSOACSubExp tv rf
            <*> mapM (mapOnSOACVName tv) op_arrs
            <*> mapM (mapOnSOACSubExp tv) nes
            <*> mapOnSOACLambda tv op
      )
      ops
    <*> mapOnSOACLambda tv bucket_fun
mapSOACM tv (Screma w arrs (ScremaForm map_lam scans reds)) =
  Screma
    <$> mapOnSOACSubExp tv w
    <*> mapM (mapOnSOACVName tv) arrs
    <*> ( ScremaForm
            <$> mapOnSOACLambda tv map_lam
            <*> forM
              scans
              ( \(Scan red_lam red_nes) ->
                  Scan
                    <$> mapOnSOACLambda tv red_lam
                    <*> mapM (mapOnSOACSubExp tv) red_nes
              )
            <*> forM
              reds
              ( \(Reduce comm red_lam red_nes) ->
                  Reduce comm
                    <$> mapOnSOACLambda tv red_lam
                    <*> mapM (mapOnSOACSubExp tv) red_nes
              )
        )
-- | Apply a statement traverser to the statements of every lambda in a
-- 'SOAC', via 'traverseLambdaStms'. All other parts of the SOAC are left
-- unchanged, since the mapper is built from 'identitySOACMapper'.
traverseSOACStms :: (Monad m) => OpStmsTraverser m (SOAC rep) rep
traverseSOACStms f = mapSOACM mapper
  where
    mapper = identitySOACMapper {mapOnSOACLambda = traverseLambdaStms f}
-- | The free variables of a scan are those of its operator and its neutral
-- elements.
instance (ASTRep rep) => FreeIn (Scan rep) where
  freeIn' (Scan lam ne) = freeIn' lam <> freeIn' ne
-- | The free variables of a reduction are those of its operator and its
-- neutral elements. The commutativity flag contributes nothing.
instance (ASTRep rep) => FreeIn (Reduce rep) where
  freeIn' (Reduce _ lam ne) = freeIn' lam <> freeIn' ne
-- | The free variables of a 'ScremaForm' are those of its map function, its
-- scans and its reductions.
--
-- The pattern variables are named after the constructor's field order
-- (lambda, scans, reductions). The previous names bound @scans@ to the
-- lambda and @lam@ to the reductions.
instance (ASTRep rep) => FreeIn (ScremaForm rep) where
  freeIn' (ScremaForm lam scans reds) =
    freeIn' lam <> freeIn' scans <> freeIn' reds
-- | The free variables of a histogram operation are those of its
-- destination shape, its race factor expression, its destination arrays,
-- its neutral elements and its operator.
instance (ASTRep rep) => FreeIn (HistOp rep) where
  freeIn' (HistOp shape rf dests nes lam) =
    freeIn' shape
      <> freeIn' rf
      <> freeIn' dests
      <> freeIn' nes
      <> freeIn' lam
-- | The free variables of a 'SOAC'.
--
-- This walks the SOAC with 'mapSOACM' in a 'State' monad. Every visited
-- subexpression, lambda and variable name has its 'freeIn'' added to the
-- accumulated 'FV', starting from 'mempty'.
instance (ASTRep rep) => FreeIn (SOAC rep) where
  freeIn' = flip execState mempty . mapSOACM free
    where
      -- Record @f x@ in the state and return @x@ unchanged. This binding
      -- must stay polymorphic, since it is used at three element types.
      walk f x = modify (<> f x) >> pure x
      free =
        SOACMapper
          { mapOnSOACSubExp = walk freeIn',
            mapOnSOACLambda = walk freeIn',
            mapOnSOACVName = walk freeIn'
          }
instance (ASTRep rep) => Substitute (SOAC rep) where
  -- Apply the name substitution to every sub-expression, lambda and
  -- variable name that 'mapSOACM' visits.
  substituteNames subst = runIdentity . mapSOACM substituter
    where
      -- Polymorphic signature so the same helper serves all three
      -- component kinds (it closes over 'subst', so it would not be
      -- generalised otherwise).
      step :: (Substitute a) => a -> Identity a
      step = Identity . substituteNames subst
      substituter =
        SOACMapper
          { mapOnSOACSubExp = step,
            mapOnSOACLambda = step,
            mapOnSOACVName = step
          }
instance (ASTRep rep) => Rename (SOAC rep) where
  -- Rename every sub-expression, lambda and variable name that
  -- 'mapSOACM' visits.
  rename soac = mapSOACM renamer soac
    where
      renamer =
        SOACMapper
          { mapOnSOACSubExp = rename,
            mapOnSOACLambda = rename,
            mapOnSOACVName = rename
          }
-- | The result types of a SOAC.
soacType :: (Typed (LParamInfo rep)) => SOAC rep -> [Type]
soacType (JVP _ _ lam) =
  -- Two copies of the lambda's return types.
  lambdaReturnType lam ++ lambdaReturnType lam
soacType (VJP _ _ lam) =
  -- The lambda's return types, followed by the type of each parameter.
  lambdaReturnType lam ++ map paramType (lambdaParams lam)
soacType (Stream outersize _ accs lam) =
  -- The return types may mention the lambda's leading parameters (the
  -- first one plus one per accumulator); those names are replaced by the
  -- stream's outer size and its accumulator values respectively.
  map (substNamesInType substs) (lambdaReturnType lam)
  where
    leading = map paramName . take (1 + length accs) $ lambdaParams lam
    substs = M.fromList $ zip leading (outersize : accs)
soacType (Scatter _w _ivs dests lam) =
  zipWith arrayOfShape (map valueType rets) shapes
  where
    (shapes, _, rets) =
      unzip3 . groupScatterResults dests $ lambdaReturnType lam
    -- Element type of a destination: the type of its first value result.
    -- A destination without value results has no derivable type, so fail
    -- with a descriptive message rather than via 'head'.
    valueType ((_, t) : _) = t
    valueType [] = error "soacType: Scatter destination has no value results"
soacType (Hist _ _ ops _bucket_fun) =
  -- Each operator's return types, lifted to that operator's shape.
  [ t `arrayOfShape` histShape op
    | op <- ops,
      t <- lambdaReturnType (histOp op)
  ]
soacType (Screma w _arrs form) =
  scremaType w form
instance TypedOp SOAC where
  -- The types computed by 'soacType' are lifted to 'ExtType' with
  -- 'staticShapes'; no scope lookup is needed.
  opType soac = pure $ staticShapes $ soacType soac
instance AliasedOp SOAC where
  -- One empty alias set per result of the SOAC.
  opAliases soac = [mempty | _ <- soacType soac]

  consumedInOp JVP {} = mempty
  consumedInOp VJP {} = mempty
  -- Only the map lambda is inspected.  A consumed lambda parameter is
  -- reported as the input array it is bound to; other consumed names
  -- are reported as-is.
  consumedInOp (Screma _ arrs (ScremaForm map_lam _ _)) =
    mapNames toInputArray $ consumedByLambda map_lam
    where
      toInputArray v = fromMaybe v $ lookup v param_arrs
      param_arrs = zip (map paramName (lambdaParams map_lam)) arrs
  -- The stream lambda's first parameter is skipped; the remaining ones
  -- are bound to the accumulators followed by the input arrays.
  -- Consumed names that map to a constant are dropped.
  consumedInOp (Stream _ arrs accs lam) =
    namesFromList
      [v | Var v <- map toInput (namesToList (consumedByLambda lam))]
    where
      toInput v = fromMaybe (Var v) $ lookup v param_inputs
      param_inputs =
        zip (map paramName (drop 1 (lambdaParams lam))) (accs ++ map Var arrs)
  -- Every destination array of a scatter is consumed.
  consumedInOp (Scatter _ _ spec _) =
    namesFromList [dest | (_, _, dest) <- spec]
  -- Every destination array of every histogram operation is consumed.
  consumedInOp (Hist _ _ ops _) =
    namesFromList $ concatMap histDest ops
-- | Transform the operator lambda of a 'HistOp', possibly changing its
-- representation.  All other fields are carried over unchanged.
mapHistOp ::
  (Lambda frep -> Lambda trep) ->
  HistOp frep ->
  HistOp trep
mapHistOp f op = op {histOp = f (histOp op)}
instance CanBeAliased SOAC where
  -- Run alias analysis (with the given alias table) on every lambda
  -- inside the SOAC: the main lambda, each histogram/scan/reduce
  -- operator, and the bucket function.  Non-lambda components are kept.
  addOpAliases aliases soac =
    case soac of
      JVP args vec lam -> JVP args vec (analyse lam)
      VJP args vec lam -> VJP args vec (analyse lam)
      Stream size arr accs lam -> Stream size arr accs (analyse lam)
      Scatter len arrs dests lam -> Scatter len arrs dests (analyse lam)
      Hist w arrs ops bucket_fun ->
        Hist w arrs (map (mapHistOp analyse) ops) (analyse bucket_fun)
      Screma w arrs (ScremaForm map_lam scans reds) ->
        Screma w arrs $
          ScremaForm
            (analyse map_lam)
            [scan {scanLambda = analyse (scanLambda scan)} | scan <- scans]
            [red {redLambda = analyse (redLambda red)} | red <- reds]
    where
      analyse = Alias.analyseLambda aliases
instance IsOp SOAC where
  -- Every SOAC is conservatively reported as neither safe nor cheap.
  safeOp _ = False
  cheapOp _ = False

  -- 'opDependencies' yields the dependencies of each SOAC result, in
  -- result order.

  -- The stream lambda takes a first parameter, then one parameter per
  -- accumulator, then one per input array (the layout 'soacType' and
  -- 'consumedInOp' also rely on), so input dependencies are supplied in
  -- that order: outer size, accumulators, arrays.
  -- NOTE(review): assumes 'lambdaDependencies' pairs inputs with the
  -- lambda parameters positionally -- confirm.
  opDependencies (Stream w arrs accs lam) =
    lambdaDependencies mempty lam $
      depsOf' w : (map depsOf' accs <> depsOfArrays w arrs)
  opDependencies (Hist w arrs ops bucket_fun) =
    concat $ zipWith (zipWith (<>)) bucket_fun_deps (map depsOfHistOp ops)
    where
      -- The bucket function returns all index results first (as many per
      -- histogram as its shape has dimensions), then all value results.
      (indices, values) =
        splitAt (sum ranks) $
          lambdaDependencies mempty bucket_fun (depsOfArrays w arrs)
      ranks = map (shapeRank . histShape) ops
      value_lengths = map (length . histNeutral) ops
      -- Regroup per histogram operation; every value result additionally
      -- depends on all index results of its operation.
      bucket_fun_deps =
        zipWith
          (\is vs -> map (mconcat is <>) vs)
          (chunks ranks indices)
          (chunks value_lengths values)
      depsOfHistOp (HistOp dest_shape rf dests nes op) =
        reductionDependencies mempty op nes $
          map (\vn -> oneName vn <> depsOfShape dest_shape <> depsOf' rf) dests
  opDependencies (Scatter w arrs outputs lam) =
    map flattenBlocks (groupScatterResults outputs lam_deps)
    where
      lam_deps = lambdaDependencies mempty lam (depsOfArrays w arrs)
      -- A destination depends on itself and on every index and value
      -- result written to it.
      flattenBlocks (_, arr, ivs) =
        oneName arr
          <> mconcat (concatMap fst ivs)
          <> mconcat (map snd ivs)
  -- 'soacType' gives a JVP two copies of the lambda's return types; both
  -- copies get the same dependencies.
  opDependencies (JVP args vec lam) =
    concat . replicate 2 $
      lambdaDependencies mempty lam $
        zipWith (<>) (map depsOf' args) (map depsOf' vec)
  -- A VJP returns the lambda's results followed by one result per lambda
  -- parameter (see 'soacType').  Pairing 'args' with 'vec' positionally
  -- would truncate whenever their lengths differ, so every lambda input
  -- gets its argument's dependencies plus those of all of 'vec'.  The
  -- per-parameter results depend on everything free in 'args', 'vec' and
  -- the lambda.  Both are conservative over-approximations.
  opDependencies (VJP args vec lam) =
    lambdaDependencies mempty lam (map ((<> vec_deps) . depsOf') args)
      <> map (const param_res_deps) (lambdaParams lam)
    where
      vec_deps = mconcat (map depsOf' vec)
      param_res_deps = freeIn args <> freeIn vec <> freeIn lam
  opDependencies (Screma w arrs (ScremaForm map_lam scans reds)) =
    scans_deps <> reds_deps <> map_deps
    where
      -- The map lambda's results are the scan inputs, then the reduce
      -- inputs, then the plain mapped results.
      (scans_in, reds_in, map_deps) =
        splitAt3 (scanResults scans) (redResults reds) $
          lambdaDependencies mempty map_lam (depsOfArrays w arrs)
      scans_deps =
        concat $ zipWith depsOfScan scans (chunks (scanSizes scans) scans_in)
      reds_deps =
        concat $ zipWith depsOfRed reds (chunks (redSizes reds) reds_in)
      depsOfScan (Scan lam nes) = reductionDependencies mempty lam nes
      depsOfRed (Reduce _ lam nes) = reductionDependencies mempty lam nes
-- | Substitute variables in the dimensions of an array type according
-- to the given map.  Primitive, accumulator and memory types are
-- returned unchanged.
substNamesInType :: M.Map VName SubExp -> Type -> Type
substNamesInType subs (Array btp shp u) =
  Array btp (Shape . map (substNamesInSubExp subs) $ shapeDims shp) u
substNamesInType _ t@Prim {} = t
substNamesInType _ t@Acc {} = t
substNamesInType _ t@Mem {} = t
-- | Replace a variable by its binding in the substitution map.
--
-- Constants, and variables that have no entry in the map, are returned
-- unchanged.
substNamesInSubExp :: M.Map VName SubExp -> SubExp -> SubExp
substNamesInSubExp subs se =
  case se of
    Constant {} -> se
    Var v -> fromMaybe se (M.lookup v subs)
-- | Add wisdom to a SOAC by rebuilding it through 'mapSOACM' in the
-- 'Identity' monad.  The lambda mapper is 'informLambda'; sub-expressions
-- and variable names are passed through untouched.
instance CanBeWise SOAC where
  addOpWisdom soac =
    soac
      & mapSOACM (SOACMapper pure (pure . informLambda) pure)
      & runIdentity
-- | Symbolic indexing of a SOAC result, as queried by the symbol table.
--
-- Only the single-index case (@[i]@) on a 'Screma' that has no scan or
-- reduction results is handled.
--
-- * Result @k@ is taken from the map lambda's body result.
-- * If that result is a variable, the lambda's statements are folded into a
--   table of 'PrimExp's.  Each lambda parameter stands for its input array
--   indexed at @i@, as reported by 'ST.index''.
-- * If the variable ends up in that table, its expression and certificates
--   are returned.
--
-- Every other case yields 'Nothing'.
--
-- Scremas with scan or reduction results are not handled.  Scan and reduce
-- operators carry their own lambdas and neutral elements, so the map lambda
-- has no accumulator parameters.  Dropping parameters to skip such
-- accumulators would pair the remaining parameters with the wrong input
-- arrays, and so would give wrong index expressions.
-- NOTE(review): assumes the map lambda's parameters correspond one-to-one
-- with the Screma's input arrays; confirm against the Screma type rules.
instance (RepTypes rep) => ST.IndexOp (SOAC rep) where
  indexOp vtable k soac [i] = do
    (lam, se, arr_params, arrs) <- lambdaAndSubExp soac
    let arr_indexes = M.fromList $ catMaybes $ zipWith arrIndex arr_params arrs
        -- Extend the table statement by statement through the lambda body.
        arr_indexes' =
          foldl expandPrimExpTable arr_indexes $ bodyStms $ lambdaBody lam
    case se of
      SubExpRes _ (Var v) -> uncurry (flip ST.Indexed) <$> M.lookup v arr_indexes'
      _ -> Nothing
    where
      -- The map lambda, its @k@th body result, its parameters and the input
      -- arrays, for a Screma that produces only map results.
      lambdaAndSubExp (Screma _ arrs (ScremaForm map_lam scans reds))
        | scanResults scans + redResults reds == 0 = do
            se <- maybeNth k $ bodyResult $ lambdaBody map_lam
            pure (map_lam, se, lambdaParams map_lam, arrs)
      lambdaAndSubExp _ = Nothing

      -- Bind a lambda parameter to the scalar obtained by indexing its input
      -- array at @i@, when the symbol table can provide it.
      arrIndex p arr = do
        ST.Indexed cs pe <- ST.index' arr [i] vtable
        pure (paramName p, (pe, cs))

      -- Record a single-name statement whose expression converts to a
      -- 'PrimExp' over known names, provided all of its certificates are
      -- present in the outer symbol table.
      expandPrimExpTable table stm
        | [v] <- patNames $ stmPat stm,
          Just (pe, cs) <-
            runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm,
          all (`ST.elem` vtable) (unCerts $ stmCerts stm) =
            M.insert v (pe, stmCerts stm <> cs) table
        | otherwise =
            table

      -- Translate a variable to a 'PrimExp'.  Look it up in the table first,
      -- recording its certificates.  Failing that, use it as a leaf if the
      -- symbol table gives it a primitive type.  Otherwise the conversion fails.
      asPrimExp table v
        | Just (e, cs) <- M.lookup v table = tell cs >> pure e
        | Just (Prim pt) <- ST.lookupType v vtable =
            pure $ LeafExp v pt
        | otherwise = lift Nothing
  indexOp _ _ _ _ = Nothing
typeCheckSOAC :: (TC.Checkable rep) => SOAC (Aliases rep) -> TC.TypeM rep ()
typeCheckSOAC :: forall rep. Checkable rep => SOAC (Aliases rep) -> TypeM rep ()
typeCheckSOAC (VJP [SubExp]
args [SubExp]
vec Lambda (Aliases rep)
lam) = do
[Arg]
args' <- (SubExp -> TypeM rep Arg) -> [SubExp] -> TypeM rep [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM SubExp -> TypeM rep Arg
forall rep. Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
args
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
lam ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases [Arg]
args'
[Type]
vec_ts <- (SubExp -> TypeM rep Type) -> [SubExp] -> TypeM rep [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM SubExp -> TypeM rep Type
forall rep. Checkable rep => SubExp -> TypeM rep Type
TC.checkSubExp [SubExp]
vec
Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
vec_ts [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases rep) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
lam) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ())
-> (Doc Any -> ErrorCase rep) -> Doc Any -> TypeM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Text -> ErrorCase rep
forall rep. Text -> ErrorCase rep
TC.TypeError (Text -> ErrorCase rep)
-> (Doc Any -> Text) -> Doc Any -> ErrorCase rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Doc Any -> Text
forall a. Doc a -> Text
docText (Doc Any -> TypeM rep ()) -> Doc Any -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
Doc Any
"Return type"
Doc Any -> Doc Any -> Doc Any
forall a. Doc a -> Doc a -> Doc a
</> Int -> Doc Any -> Doc Any
forall ann. Int -> Doc ann -> Doc ann
PP.indent Int
2 ([Type] -> Doc Any
forall ann. [Type] -> Doc ann
forall a ann. Pretty a => a -> Doc ann
pretty (Lambda (Aliases rep) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
lam))
Doc Any -> Doc Any -> Doc Any
forall a. Doc a -> Doc a -> Doc a
</> Doc Any
"does not match type of seed vector"
Doc Any -> Doc Any -> Doc Any
forall a. Doc a -> Doc a -> Doc a
</> Int -> Doc Any -> Doc Any
forall ann. Int -> Doc ann -> Doc ann
PP.indent Int
2 ([Type] -> Doc Any
forall ann. [Type] -> Doc ann
forall a ann. Pretty a => a -> Doc ann
pretty [Type]
vec_ts)
typeCheckSOAC (JVP [SubExp]
args [SubExp]
vec Lambda (Aliases rep)
lam) = do
[Arg]
args' <- (SubExp -> TypeM rep Arg) -> [SubExp] -> TypeM rep [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM SubExp -> TypeM rep Arg
forall rep. Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
args
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
lam ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases [Arg]
args'
[Type]
vec_ts <- (SubExp -> TypeM rep Type) -> [SubExp] -> TypeM rep [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM SubExp -> TypeM rep Type
forall rep. Checkable rep => SubExp -> TypeM rep Type
TC.checkSubExp [SubExp]
vec
Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
vec_ts [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
args') (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ())
-> (Doc Any -> ErrorCase rep) -> Doc Any -> TypeM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Text -> ErrorCase rep
forall rep. Text -> ErrorCase rep
TC.TypeError (Text -> ErrorCase rep)
-> (Doc Any -> Text) -> Doc Any -> ErrorCase rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Doc Any -> Text
forall a. Doc a -> Text
docText (Doc Any -> TypeM rep ()) -> Doc Any -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
Doc Any
"Parameter type"
Doc Any -> Doc Any -> Doc Any
forall a. Doc a -> Doc a -> Doc a
</> Int -> Doc Any -> Doc Any
forall ann. Int -> Doc ann -> Doc ann
PP.indent Int
2 ([Type] -> Doc Any
forall ann. [Type] -> Doc ann
forall a ann. Pretty a => a -> Doc ann
pretty ([Type] -> Doc Any) -> [Type] -> Doc Any
forall a b. (a -> b) -> a -> b
$ (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
args')
Doc Any -> Doc Any -> Doc Any
forall a. Doc a -> Doc a -> Doc a
</> Doc Any
"does not match type of seed vector"
Doc Any -> Doc Any -> Doc Any
forall a. Doc a -> Doc a -> Doc a
</> Int -> Doc Any -> Doc Any
forall ann. Int -> Doc ann -> Doc ann
PP.indent Int
2 ([Type] -> Doc Any
forall ann. [Type] -> Doc ann
forall a ann. Pretty a => a -> Doc ann
pretty [Type]
vec_ts)
typeCheckSOAC (Stream SubExp
size [VName]
arrexps [SubExp]
accexps Lambda (Aliases rep)
lam) = do
[Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
size
[Arg]
accargs <- (SubExp -> TypeM rep Arg) -> [SubExp] -> TypeM rep [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM SubExp -> TypeM rep Arg
forall rep. Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
accexps
[Type]
arrargs <- (VName -> TypeM rep Type) -> [VName] -> TypeM rep [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> TypeM rep Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
arrexps
[Arg]
_ <- SubExp -> [VName] -> TypeM rep [Arg]
forall rep. Checkable rep => SubExp -> [VName] -> TypeM rep [Arg]
TC.checkSOACArrayArgs SubExp
size [VName]
arrexps
Param (LParamInfo rep)
chunk <- case Lambda (Aliases rep) -> [LParam (Aliases rep)]
forall {rep}. Lambda rep -> [Param (LParamInfo rep)]
lambdaParams Lambda (Aliases rep)
lam of
LParam (Aliases rep)
chunk : [LParam (Aliases rep)]
_ -> Param (LParamInfo rep) -> TypeM rep (Param (LParamInfo rep))
forall a. a -> TypeM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Param (LParamInfo rep)
LParam (Aliases rep)
chunk
[] -> ErrorCase rep -> TypeM rep (Param (LParamInfo rep))
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep (Param (LParamInfo rep)))
-> ErrorCase rep -> TypeM rep (Param (LParamInfo rep))
forall a b. (a -> b) -> a -> b
$ Text -> ErrorCase rep
forall rep. Text -> ErrorCase rep
TC.TypeError Text
"Stream lambda without parameters."
let asArg :: a -> (a, b)
asArg a
t = (a
t, b
forall a. Monoid a => a
mempty)
inttp :: TypeBase shape u
inttp = PrimType -> TypeBase shape u
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
lamarrs' :: [Type]
lamarrs' = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> SubExp -> Type
forall d u.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) u -> d -> TypeBase (ShapeBase d) u
`setOuterSize` VName -> SubExp
Var (Param (LParamInfo rep) -> VName
forall dec. Param dec -> VName
paramName Param (LParamInfo rep)
chunk)) [Type]
arrargs
acc_len :: Int
acc_len = [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
accexps
lamrtp :: [Type]
lamrtp = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
take Int
acc_len ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda (Aliases rep) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
lam
Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ((Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
accargs [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== [Type]
lamrtp) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ())
-> (Text -> ErrorCase rep) -> Text -> TypeM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Text -> ErrorCase rep
forall rep. Text -> ErrorCase rep
TC.TypeError (Text -> TypeM rep ()) -> Text -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
Text
"Stream with inconsistent accumulator type in lambda."
let fake_lamarrs' :: [Arg]
fake_lamarrs' = (Type -> Arg) -> [Type] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Arg
forall {b} {a}. Monoid b => a -> (a, b)
asArg [Type]
lamarrs'
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
lam ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ Type -> Arg
forall {b} {a}. Monoid b => a -> (a, b)
asArg Type
forall {shape} {u}. TypeBase shape u
inttp Arg -> [Arg] -> [Arg]
forall a. a -> [a] -> [a]
: [Arg]
accargs [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
fake_lamarrs'
typeCheckSOAC (Scatter SubExp
w [VName]
arrs ScatterSpec VName
as Lambda (Aliases rep)
lam) = do
[Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
w
let ([Shape]
as_ws, [Int]
as_ns, [VName]
_as_vs) = ScatterSpec VName -> ([Shape], [Int], [VName])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ScatterSpec VName
as
indexes :: Int
indexes = [Int] -> Int
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (Int -> Int -> Int) -> [Int] -> [Int] -> [Int]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Int -> Int -> Int
forall a. Num a => a -> a -> a
(*) [Int]
as_ns ([Int] -> [Int]) -> [Int] -> [Int]
forall a b. (a -> b) -> a -> b
$ (Shape -> Int) -> [Shape] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map Shape -> Int
forall a. ShapeBase a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Shape]
as_ws
rts :: [Type]
rts = Lambda (Aliases rep) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
lam
rtsI :: [Type]
rtsI = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
take Int
indexes [Type]
rts
rtsV :: [Type]
rtsV = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
drop Int
indexes [Type]
rts
Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
rts Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [Int] -> Int
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Int]
as_ns Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((Int -> Int -> Int) -> [Int] -> [Int] -> [Int]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Int -> Int -> Int
forall a. Num a => a -> a -> a
(*) [Int]
as_ns ([Int] -> [Int]) -> [Int] -> [Int]
forall a b. (a -> b) -> a -> b
$ (Shape -> Int) -> [Shape] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map Shape -> Int
forall a. ShapeBase a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Shape]
as_ws)) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
Text -> ErrorCase rep
forall rep. Text -> ErrorCase rep
TC.TypeError Text
"Scatter: number of index types, value types and array outputs do not match."
[Type] -> (Type -> TypeM rep ()) -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Type]
rtsI ((Type -> TypeM rep ()) -> TypeM rep ())
-> (Type -> TypeM rep ()) -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ \Type
rtI ->
Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64 Type -> Type -> Bool
forall a. Eq a => a -> a -> Bool
== Type
rtI) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
Text -> ErrorCase rep
forall rep. Text -> ErrorCase rep
TC.TypeError Text
"Scatter: Index return type must be i64."
[([Type], (Shape, Int, VName))]
-> (([Type], (Shape, Int, VName)) -> TypeM rep ()) -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([[Type]] -> ScatterSpec VName -> [([Type], (Shape, Int, VName))]
forall a b. [a] -> [b] -> [(a, b)]
zip ([Int] -> [Type] -> [[Type]]
forall a. [Int] -> [a] -> [[a]]
chunks [Int]
as_ns [Type]
rtsV) ScatterSpec VName
as) ((([Type], (Shape, Int, VName)) -> TypeM rep ()) -> TypeM rep ())
-> (([Type], (Shape, Int, VName)) -> TypeM rep ()) -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ \([Type]
rtVs, (Shape
aw, Int
_, VName
a)) -> do
(SubExp -> TypeM rep ()) -> Shape -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ([Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64]) Shape
aw
[Type] -> (Type -> TypeM rep ()) -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Type]
rtVs ((Type -> TypeM rep ()) -> TypeM rep ())
-> (Type -> TypeM rep ()) -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ \Type
rtV -> [Type] -> VName -> TypeM rep ()
forall rep. Checkable rep => [Type] -> VName -> TypeM rep ()
TC.requireI [Type -> Shape -> Type
arrayOfShape Type
rtV Shape
aw] VName
a
Names -> TypeM rep ()
forall rep. Checkable rep => Names -> TypeM rep ()
TC.consume (Names -> TypeM rep ()) -> TypeM rep Names -> TypeM rep ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> TypeM rep Names
forall rep. Checkable rep => VName -> TypeM rep Names
TC.lookupAliases VName
a
[Arg]
arrargs <- SubExp -> [VName] -> TypeM rep [Arg]
forall rep. Checkable rep => SubExp -> [VName] -> TypeM rep [Arg]
TC.checkSOACArrayArgs SubExp
w [VName]
arrs
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
lam [Arg]
arrargs
typeCheckSOAC (Hist SubExp
w [VName]
arrs [HistOp (Aliases rep)]
ops Lambda (Aliases rep)
bucket_fun) = do
[Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
w
[HistOp (Aliases rep)]
-> (HistOp (Aliases rep) -> TypeM rep ()) -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [HistOp (Aliases rep)]
ops ((HistOp (Aliases rep) -> TypeM rep ()) -> TypeM rep ())
-> (HistOp (Aliases rep) -> TypeM rep ()) -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ \(HistOp Shape
dest_shape SubExp
rf [VName]
dests [SubExp]
nes Lambda (Aliases rep)
op) -> do
[Arg]
nes' <- (SubExp -> TypeM rep Arg) -> [SubExp] -> TypeM rep [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM SubExp -> TypeM rep Arg
forall rep. Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
nes
(SubExp -> TypeM rep ()) -> Shape -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ([Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64]) Shape
dest_shape
[Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
rf
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
op ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases ([Arg] -> [Arg]) -> [Arg] -> [Arg]
forall a b. (a -> b) -> a -> b
$ [Arg]
nes' [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
nes'
let nes_t :: [Type]
nes_t = (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
nes'
Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
nes_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases rep) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
op) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ())
-> (Text -> ErrorCase rep) -> Text -> TypeM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Text -> ErrorCase rep
forall rep. Text -> ErrorCase rep
TC.TypeError (Text -> TypeM rep ()) -> Text -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
Text
"Operator has return type "
Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> [Type] -> Text
forall a. Pretty a => [a] -> Text
prettyTuple (Lambda (Aliases rep) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
op)
Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
" but neutral element has type "
Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> [Type] -> Text
forall a. Pretty a => [a] -> Text
prettyTuple [Type]
nes_t
[(Type, VName)] -> ((Type, VName) -> TypeM rep ()) -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Type] -> [VName] -> [(Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Type]
nes_t [VName]
dests) (((Type, VName) -> TypeM rep ()) -> TypeM rep ())
-> ((Type, VName) -> TypeM rep ()) -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ \(Type
t, VName
dest) -> do
[Type] -> VName -> TypeM rep ()
forall rep. Checkable rep => [Type] -> VName -> TypeM rep ()
TC.requireI [Type
t Type -> Shape -> Type
`arrayOfShape` Shape
dest_shape] VName
dest
Names -> TypeM rep ()
forall rep. Checkable rep => Names -> TypeM rep ()
TC.consume (Names -> TypeM rep ()) -> TypeM rep Names -> TypeM rep ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> TypeM rep Names
forall rep. Checkable rep => VName -> TypeM rep Names
TC.lookupAliases VName
dest
[Arg]
img' <- SubExp -> [VName] -> TypeM rep [Arg]
forall rep. Checkable rep => SubExp -> [VName] -> TypeM rep [Arg]
TC.checkSOACArrayArgs SubExp
w [VName]
arrs
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
bucket_fun [Arg]
img'
[Type]
nes_ts <- [[Type]] -> [Type]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([[Type]] -> [Type]) -> TypeM rep [[Type]] -> TypeM rep [Type]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (HistOp (Aliases rep) -> TypeM rep [Type])
-> [HistOp (Aliases rep)] -> TypeM rep [[Type]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ((SubExp -> TypeM rep Type) -> [SubExp] -> TypeM rep [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM SubExp -> TypeM rep Type
forall t (m :: * -> *). HasScope t m => SubExp -> m Type
subExpType ([SubExp] -> TypeM rep [Type])
-> (HistOp (Aliases rep) -> [SubExp])
-> HistOp (Aliases rep)
-> TypeM rep [Type]
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. HistOp (Aliases rep) -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral) [HistOp (Aliases rep)]
ops
let bucket_ret_t :: [Type]
bucket_ret_t =
(HistOp (Aliases rep) -> [Type])
-> [HistOp (Aliases rep)] -> [Type]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap ((Int -> Type -> [Type]
forall a. Int -> a -> [a]
`replicate` PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) (Int -> [Type])
-> (HistOp (Aliases rep) -> Int) -> HistOp (Aliases rep) -> [Type]
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank (Shape -> Int)
-> (HistOp (Aliases rep) -> Shape) -> HistOp (Aliases rep) -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. HistOp (Aliases rep) -> Shape
forall rep. HistOp rep -> Shape
histShape) [HistOp (Aliases rep)]
ops
[Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
nes_ts
Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
bucket_ret_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases rep) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
bucket_fun) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ())
-> (Text -> ErrorCase rep) -> Text -> TypeM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Text -> ErrorCase rep
forall rep. Text -> ErrorCase rep
TC.TypeError (Text -> TypeM rep ()) -> Text -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
Text
"Bucket function has return type "
Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> [Type] -> Text
forall a. Pretty a => [a] -> Text
prettyTuple (Lambda (Aliases rep) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
bucket_fun)
Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
" but should have type "
Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> [Type] -> Text
forall a. Pretty a => [a] -> Text
prettyTuple [Type]
bucket_ret_t
-- Screma: check the input arrays against the map function, then every scan
-- and reduction operator against its neutral elements, and finally that the
-- leading results of the map function have the types of the scan and
-- reduction neutral elements (scans first, then reductions).
typeCheckSOAC (Screma w arrs (ScremaForm map_lam scans reds)) = do
  TC.require [Prim int64] w
  arrs' <- TC.checkSOACArrayArgs w arrs
  TC.checkLambda map_lam arrs'

  scan_nes' <-
    concat
      <$> forM
        scans
        ( \(Scan scan_lam scan_nes) ->
            checkOperator "Scan function returns type " scan_lam scan_nes
        )

  red_nes' <-
    concat
      <$> forM
        reds
        ( \(Reduce _ red_lam red_nes) ->
            checkOperator "Reduce function returns type " red_lam red_nes
        )

  let map_lam_ts = lambdaReturnType map_lam
      -- Types the map function must produce as its leading results.
      acc_ts = map TC.argType (scan_nes' ++ red_nes')

  unless (take (length acc_ts) map_lam_ts == acc_ts)
    . TC.bad
    . TC.TypeError
    $ "Map function return type "
      <> prettyTuple map_lam_ts
      <> " wrong for given scan and reduction functions."
  where
    -- Shared check for scan and reduction operators.
    --
    -- It type-checks the neutral elements, then checks the operator lambda
    -- against two copies of them, with aliasing information removed. It
    -- requires the operator's return type to equal the neutral-element
    -- types, reporting a type error prefixed with 'what' otherwise.
    -- It returns the checked neutral-element arguments.
    checkOperator what op nes = do
      nes' <- mapM TC.checkArg nes
      let nes_t = map TC.argType nes'
      -- The operator combines two values, so each neutral element is
      -- supplied twice.
      TC.checkLambda op $ map TC.noArgAliases $ nes' ++ nes'
      unless (nes_t == lambdaReturnType op) $
        TC.bad . TC.TypeError $
          what
            <> prettyTuple (lambdaReturnType op)
            <> " but neutral element has type "
            <> prettyTuple nes_t
      pure nes'
-- | Rephrasing a SOAC only rewrites the lambdas it contains. This covers the
-- main lambda as well as the operators nested inside histogram, scan and
-- reduction specifications. Every other component (widths, arrays, shapes,
-- neutral elements, commutativity) is carried over unchanged.
instance RephraseOp SOAC where
  rephraseInOp r soac = case soac of
    VJP args vec lam ->
      VJP args vec <$> onLam lam
    JVP args vec lam ->
      JVP args vec <$> onLam lam
    Stream w arrs acc lam ->
      Stream w arrs acc <$> onLam lam
    Scatter w arrs dests lam ->
      Scatter w arrs dests <$> onLam lam
    Hist w arrs ops lam ->
      Hist w arrs <$> traverse onHistOp ops <*> onLam lam
    Screma w arrs (ScremaForm lam scans reds) ->
      Screma w arrs
        <$> ( ScremaForm
                <$> onLam lam
                <*> traverse onScan scans
                <*> traverse onRed reds
            )
    where
      -- Every lambda is rephrased with the same rephraser.
      onLam = rephraseLambda r

      onHistOp (HistOp dest_shape rf dests nes op) =
        HistOp dest_shape rf dests nes <$> onLam op

      onScan (Scan op nes) =
        (`Scan` nes) <$> onLam op

      onRed (Reduce comm op nes) =
        (\op' -> Reduce comm op' nes) <$> onLam op
-- | Metrics for a SOAC are gathered by recording the metrics of every lambda
-- it contains. The lambdas are recorded inside a context named after the
-- SOAC constructor.
instance (OpMetrics (Op rep)) => OpMetrics (SOAC rep) where
  opMetrics soac = case soac of
    VJP _ _ lam ->
      inside "VJP" (lambdaMetrics lam)
    JVP _ _ lam ->
      inside "JVP" (lambdaMetrics lam)
    Stream _ _ _ lam ->
      inside "Stream" (lambdaMetrics lam)
    Scatter _len _ _ lam ->
      inside "Scatter" (lambdaMetrics lam)
    Hist _ _ ops bucket_fun ->
      inside "Hist" $ do
        -- Histogram operators first, then the bucket function.
        mapM_ (lambdaMetrics . histOp) ops
        lambdaMetrics bucket_fun
    Screma _ _ (ScremaForm map_lam scans reds) ->
      -- Map function, then scan operators, then reduction operators.
      inside "Screma" . mapM_ lambdaMetrics $
        map_lam : map scanLambda scans ++ map redLambda reds
instance (PrettyRep rep) => PP.Pretty (SOAC rep) where
pretty :: forall ann. SOAC rep -> Doc ann
pretty (VJP [SubExp]
args [SubExp]
vec Lambda rep
lam) =
Doc ann
"vjp"
Doc ann -> Doc ann -> Doc ann
forall a. Semigroup a => a -> a -> a
<> Doc ann -> Doc ann
forall ann. Doc ann -> Doc ann
parens
( Doc ann -> Doc ann
forall ann. Doc ann -> Doc ann
PP.align (Doc ann -> Doc ann) -> Doc ann -> Doc ann
forall a b. (a -> b) -> a -> b
$
Doc ann -> Doc ann
forall ann. Doc ann -> Doc ann
PP.braces ([Doc ann] -> Doc ann
forall a. [Doc a] -> Doc a
commasep ([Doc ann] -> Doc ann) -> [Doc ann] -> Doc ann
forall a b. (a -> b) -> a -> b
$ (SubExp -> Doc ann) -> [SubExp] -> [Doc ann]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> Doc ann
forall a ann. Pretty a => a -> Doc ann
forall ann. SubExp -> Doc ann
pretty [SubExp]
args)
Doc ann -> Doc ann -> Doc ann
forall a. Semigroup a => a -> a -> a
<> Doc ann
forall ann. Doc ann
comma Doc ann -> Doc ann -> Doc ann
forall a. Doc a -> Doc a -> Doc a
</> Doc ann -> Doc ann
forall ann. Doc ann -> Doc ann
PP.braces ([Doc ann] -> Doc ann
forall a. [Doc a] -> Doc a
commasep ([Doc ann] -> Doc ann) -> [Doc ann] -> Doc ann
forall a b. (a -> b) -> a -> b
$ (SubExp -> Doc ann) -> [SubExp] -> [Doc ann]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> Doc ann
forall a ann. Pretty a => a -> Doc ann
forall ann. SubExp -> Doc ann
pretty [SubExp]
vec)
Doc ann -> Doc ann -> Doc ann
forall a. Semigroup a => a -> a -> a
<> Doc ann
forall ann. Doc ann
comma Doc ann -> Doc ann -> Doc ann
forall a. Doc a -> Doc a -> Doc a
</> Lambda rep -> Doc ann
forall a ann. Pretty a => a -> Doc ann
forall ann. Lambda rep -> Doc ann
pretty Lambda rep
lam
)
pretty (JVP [SubExp]
args [SubExp]
vec Lambda rep
lam) =
Doc ann
"jvp"
Doc ann -> Doc ann -> Doc ann
forall a. Semigroup a => a -> a -> a
<> Doc ann -> Doc ann
forall ann. Doc ann -> Doc ann
parens
( Doc ann -> Doc ann
forall ann. Doc ann -> Doc ann
PP.align (Doc ann -> Doc ann) -> Doc ann -> Doc ann
forall a b. (a -> b) -> a -> b
$
Doc ann -> Doc ann
forall ann. Doc ann -> Doc ann
PP.braces ([Doc ann] -> Doc ann
forall a. [Doc a] -> Doc a
commasep ([Doc ann] -> Doc ann) -> [Doc ann] -> Doc ann
forall a b. (a -> b) -> a -> b
$ (SubExp -> Doc ann) -> [SubExp] -> [Doc ann]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> Doc ann
forall a ann. Pretty a => a -> Doc ann
forall ann. SubExp -> Doc ann
pretty [SubExp]
args)
Doc ann -> Doc ann -> Doc ann
forall a. Semigroup a => a -> a -> a
<> Doc ann
forall ann. Doc ann
comma Doc ann -> Doc ann -> Doc ann
forall a. Doc a -> Doc a -> Doc a
</> Doc ann -> Doc ann
forall ann. Doc ann -> Doc ann
PP.braces ([Doc ann] -> Doc ann
forall a. [Doc a] -> Doc a
commasep ([Doc ann] -> Doc ann) -> [Doc ann] -> Doc ann
forall a b. (a -> b) -> a -> b
$ (SubExp -> Doc ann) -> [SubExp] -> [Doc ann]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> Doc ann
forall a ann. Pretty a => a -> Doc ann
forall ann. SubExp -> Doc ann
pretty [SubExp]
vec)
Doc ann -> Doc ann -> Doc ann
forall a. Semigroup a => a -> a -> a
<> Doc ann
forall ann. Doc ann
comma Doc ann -> Doc ann -> Doc ann
forall a. Doc a -> Doc a -> Doc a
</> Lambda rep -> Doc ann
forall a ann. Pretty a => a -> Doc ann
forall ann. Lambda rep -> Doc ann
pretty Lambda rep
lam
)
pretty (Stream SubExp
size [VName]
arrs [SubExp]
acc Lambda rep
lam) =
SubExp -> [VName] -> [SubExp] -> Lambda rep -> Doc ann
forall rep inp ann.
(PrettyRep rep, Pretty inp) =>
SubExp -> [inp] -> [SubExp] -> Lambda rep -> Doc ann
ppStream SubExp
size [VName]
arrs [SubExp]
acc Lambda rep
lam
pretty (Scatter SubExp
w [VName]
arrs ScatterSpec VName
dests Lambda rep
lam) =
SubExp -> [VName] -> ScatterSpec VName -> Lambda rep -> Doc ann
forall rep inp ann.
(PrettyRep rep, Pretty inp) =>
SubExp -> [inp] -> ScatterSpec VName -> Lambda rep -> Doc ann
ppScatter SubExp
w [VName]
arrs ScatterSpec VName
dests Lambda rep
lam
pretty (Hist SubExp
w [VName]
arrs [HistOp rep]
ops Lambda rep
bucket_fun) =
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> Doc ann
forall rep inp ann.
(PrettyRep rep, Pretty inp) =>
SubExp -> [inp] -> [HistOp rep] -> Lambda rep -> Doc ann
ppHist SubExp
w [VName]
arrs [HistOp rep]
ops Lambda rep
bucket_fun
-- A Screma with no scans and no reductions prints as "map", one with only
-- reductions as "redomap", and one with only scans as "scanomap".  A Screma
-- with both scans and reductions fails every guard and falls through to the
-- generic 'ppScrema' equation below.
pretty (Screma w arrs (ScremaForm map_lam scans reds))
  | null scans,
    null reds =
      ppForm "map" mempty
  | null scans =
      ppForm "redomap" (ppOps reds)
  | null reds =
      ppForm "scanomap" (ppOps scans)
  where
    -- keyword(w, arrs, map_lam<extra>) -- the prefix shared by all three forms.
    ppForm keyword extra =
      keyword
        <> (parens . align)
          ( pretty w
              <> comma
              </> ppTuple' (map pretty arrs)
              <> comma
              </> pretty map_lam
              <> extra
          )
    -- A trailing ", {op, op, ...}" list.  It mentions no local variables, so
    -- it is generalised and can be used for both Reduce and Scan operators.
    ppOps ops =
      comma </> PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty ops)
pretty (Screma w arrs form) = ppScrema w arrs form
-- | Pretty-print a general 'Screma' as
-- @screma(w, arrs, map_lam, {scans}, {reds})@.  The scan operators and the
-- reduce operators are each printed inside braces, with a comma and a line
-- break between consecutive operators.
ppScrema ::
  (PrettyRep rep, Pretty inp) => SubExp -> [inp] -> ScremaForm rep -> Doc ann
ppScrema w arrs (ScremaForm map_lam scans reds) =
  "screma" <> (parens . align) fields
  where
    fields =
      pretty w
        <> comma
        </> ppTuple' (map pretty arrs)
        <> comma
        </> pretty map_lam
        <> comma
        </> opList scans
        <> comma
        </> opList reds
    -- Mentions no local variables, so it is generalised and can be applied
    -- to both the Scan list and the Reduce list.
    opList xs = PP.braces . mconcat . intersperse (comma <> PP.line) $ map pretty xs
-- | Pretty-print a 'Stream' SOAC as @streamSeq(size, arrs, acc, lam)@.
-- Both the input arrays and the accumulator values are printed with
-- @ppTuple'@.
ppStream ::
  (PrettyRep rep, Pretty inp) => SubExp -> [inp] -> [SubExp] -> Lambda rep -> Doc ann
ppStream chunk inputs accs body =
  "streamSeq" <> (parens . align) fields
  where
    fields =
      pretty chunk
        <> comma
        </> ppTuple' (map pretty inputs)
        <> comma
        </> ppTuple' (map pretty accs)
        <> comma
        </> pretty body
-- | Pretty-print a 'Scatter' SOAC as
-- @scatter(w, arrs, dest1, dest2, ..., lam)@.  The input arrays are printed
-- with @ppTuple'@.  Each destination triple of the 'ScatterSpec' is printed
-- with its own 'Pretty' instance, and the triples are comma-separated with
-- no surrounding brackets.
--
-- The destination parameter uses the module's 'ScatterSpec' alias, the same
-- type the 'Scatter' constructor carries.
ppScatter ::
  (PrettyRep rep, Pretty inp) => SubExp -> [inp] -> ScatterSpec VName -> Lambda rep -> Doc ann
ppScatter w arrs dests lam =
  "scatter"
    <> (parens . align)
      ( pretty w
          <> comma
          </> ppTuple' (map pretty arrs)
          <> comma
          </> commasep (map pretty dests)
          <> comma
          </> pretty lam
      )
-- | A scan operator prints as its lambda, a comma, a line break, and then
-- its neutral elements comma-separated inside braces.
instance (PrettyRep rep) => Pretty (Scan rep) where
  pretty (Scan op neutral) =
    pretty op <> comma </> neutralElems
    where
      neutralElems = PP.braces (commasep (map pretty neutral))
-- | Prefix printed in front of a reduction operator: @commutative @ (with a
-- trailing space) for 'Commutative', and the empty document for
-- 'Noncommutative'.
ppComm :: Commutativity -> Doc ann
ppComm comm =
  case comm of
    Commutative -> "commutative "
    Noncommutative -> mempty
-- | A reduction prints as the 'ppComm' prefix, the operator lambda, a comma,
-- a line break, and then its neutral elements comma-separated inside braces.
instance (PrettyRep rep) => Pretty (Reduce rep) where
  pretty (Reduce commutativity op neutral) =
    mconcat [ppComm commutativity, pretty op, comma] </> neutralElems
    where
      neutralElems = PP.braces . commasep $ map pretty neutral
-- | Pretty-print a 'Hist' SOAC as
-- @hist(w, arrs, {op1, op2, ...}, bucket_fun)@, with a comma and a line
-- break between consecutive histogram operators.
--
-- The argument list is wrapped in @parens . align@, as in 'ppScrema',
-- 'ppStream' and 'ppScatter'.  Line breaks inside it are therefore laid out
-- relative to the opening parenthesis, so multi-line @hist@ output has the
-- same shape as the other SOAC printers.  Only layout is affected; the
-- printed tokens are unchanged.
ppHist ::
  (PrettyRep rep, Pretty inp) =>
  SubExp ->
  [inp] ->
  [HistOp rep] ->
  Lambda rep ->
  Doc ann
ppHist w arrs ops bucket_fun =
  "hist"
    <> (parens . align)
      ( pretty w
          <> comma
          </> ppTuple' (map pretty arrs)
          <> comma
          </> PP.braces (mconcat $ intersperse (comma <> PP.line) $ map ppOp ops)
          <> comma
          </> pretty bucket_fun
      )
  where
    -- One histogram operator: the destination shape, the @rf@ subexpression,
    -- the destination arrays in braces, the neutral elements via @ppTuple'@,
    -- and finally the operator lambda.
    ppOp (HistOp dest_w rf dests nes op) =
      pretty dest_w
        <> comma
        <+> pretty rf
        <> comma
        <+> PP.braces (commasep $ map pretty dests)
        <> comma
        </> ppTuple' (map pretty nes)
        <> comma
        </> pretty op