module Futhark.AD.Rev.Scan (diffScan, diffScanVec, diffScanAdd) where
import Control.Monad
import Data.List (transpose)
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify (simplifyLambda)
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (chunk)
-- | Which operand group of the scan operator to differentiate with respect
-- to.  'mkScanAdjointLam' splits the operator's parameters after the first
-- @length (lambdaReturnType op)@ entries and picks one side with this.
data FirstOrSecond
  = -- | The leading operand group.
    WrtFirst
  | -- | The trailing operand group.
    WrtSecond
-- | @identityM n t@ binds an @n × n@ identity matrix with entries of type @t@:
-- diagonal entries are 'oneExp' and all others are 'zeroExp'.  Every entry
-- is bound to a fresh variable with base name @id@, row by row.
--
-- For @n <= 0@ the result is empty, and @t@ is never inspected.
identityM :: Int -> Type -> ADM [[SubExp]]
identityM n t =
  forM [1 .. n] $ \row ->
    forM [1 .. n] $ \col ->
      letSubExp "id" $ if col == row then oneExp t else zeroExp t
-- | Symbolic product of two matrices of 'PrimExp's, each given as a list of
-- rows.
--
-- Each result entry is a left fold with '~+~' over the pairwise products
-- ('~*~') of a row of @m1@ and a column of @m2@.  The fold is seeded with
-- the blank (zero) value of the element type @t@.  Mismatched inner
-- dimensions are silently truncated by 'zipWith'.
matrixMul :: [[PrimExp VName]] -> [[PrimExp VName]] -> PrimType -> [[PrimExp VName]]
matrixMul m1 m2 t = map rowTimesM2 m1
  where
    zero = primExpFromSubExp t $ Constant $ blankPrimValue t
    cols = transpose m2
    dot r q = foldl (~+~) zero (zipWith (~*~) r q)
    rowTimesM2 r = map (dot r) cols
-- | Symbolic matrix–vector product: the matrix @m@ is a list of rows.
--
-- The @i@-th result is the left '~+~'-fold, seeded with the blank (zero)
-- value of @t@, of @v@ multiplied elementwise ('~*~') with row @i@.  The
-- vector is the left operand of each product.
matrixVecMul :: [[PrimExp VName]] -> [PrimExp VName] -> PrimType -> [PrimExp VName]
matrixVecMul m v t =
  [foldl (~+~) zero products | row <- m, let products = zipWith (~*~) v row]
  where
    zero = primExpFromSubExp t (Constant (blankPrimValue t))
-- | Elementwise symbolic sum ('~+~') of two vectors.  The result is as long
-- as the shorter input.
vectorAdd :: [PrimExp VName] -> [PrimExp VName] -> [PrimExp VName]
vectorAdd xs ys = [x ~+~ y | (x, y) <- zip xs ys]
-- | Split @lst@ into consecutive groups using 'chunk'.  The intended result
-- is one group per scan of the special case (@specialScans s@).
--
-- The group size is @length lst `div` specialScans s@ (integer division).
-- NOTE(review): presumably @length lst@ is always a multiple of
-- @specialScans s@ here — confirm against callers.
orderArgs :: Special -> [a] -> [[a]]
orderArgs s lst = chunk groupSize lst
  where
    groupSize = length lst `div` specialScans s
-- | Build the vector–Jacobian lambda of the scan operator @lam0@ with respect
-- to one of its two operand groups.
--
-- The operator (after 'renameLambda') has its parameters split after the
-- first @length (lambdaReturnType lam0)@ entries:
--
-- * 'WrtFirst' differentiates with respect to the leading group.
-- * 'WrtSecond' differentiates with respect to the trailing group.
--
-- @adjs@ are the adjoint values seeded for the operator's results.
mkScanAdjointLam :: VjpOps -> Lambda SOACS -> FirstOrSecond -> [SubExp] -> ADM (Lambda SOACS)
mkScanAdjointLam ops lam0 which adjs = do
  lam <- renameLambda lam0
  let numResults = length (lambdaReturnType lam0)
      (firstOperands, secondOperands) = splitAt numResults (lambdaParams lam)
      wrtParams = case which of
        WrtFirst -> firstOperands
        WrtSecond -> secondOperands
  vjpLambda ops (map AdjVal adjs) (map paramName wrtParams) lam
-- | Construct the map lambda that produces the per-element input of the
-- adjoint scan.
--
-- The lambda takes a single @i64@ index @i@ and works at the reversed
-- position @j = w - (i + 1)@.  It always reads @ys_adj[j]@.
--
-- * For @i == 0@ it pairs those adjoints with the identity matrix.
-- * Otherwise it evaluates the Jacobian of @scn_lam@ with respect to its
--   first operand group (one lambda per identity row) at
--   @(ys[j], xs[w - i])@, and pairs the adjoints with the transposed
--   Jacobian.
--
-- In both cases the output is grouped per scan with 'orderArgs'.  The
-- identity/Jacobian part is trimmed according to 'specialCase'.
--
-- NOTE(review): @d@ is only used by the 'ZeroQuadrant' trimming; presumably
-- it is the number of scanned values — confirm against callers.
mkScanFusedMapLam ::
  VjpOps ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  [VName] ->
  [VName] ->
  Special ->
  Int ->
  ADM (Lambda SOACS)
mkScanFusedMapLam ops w scn_lam xs ys ys_adj s d = do
  let sc = specialCase s
      k = specialSubSize s
  ys_ts <- traverse lookupType ys
  -- One identity row per result; each row seeds one Jacobian lambda.
  -- With no results the matrix is empty and no element type is needed.
  idmat <- case ys_ts of
    [] -> pure []
    y_t : _ -> identityM (length ys) (rowType y_t)
  lams <- traverse (mkScanAdjointLam ops scn_lam WrtFirst) idmat
  par_i <- newParam "i" $ Prim int64
  let i = paramName par_i

      -- Position read in this iteration: the index space is walked backwards.
      reversedIdx :: ADM SubExp
      reversedIdx = letSubExp "j" =<< toExp (pe64 w - (le64 i + 1))

      -- Read every result adjoint at the given position.
      readAdjsAt :: SubExp -> ADM [SubExp]
      readAdjsAt j =
        forM ys_adj $ \y_ ->
          letSubExp (baseString y_ ++ "_j") =<< eIndex y_ [eSubExp j]

      -- First iteration: adjoints paired with the (trimmed) identity matrix.
      firstIter :: ADM (Body SOACS)
      firstIter = buildBody_ $ do
        j <- reversedIdx
        y_s <- readAdjsAt j
        let adjGroups = orderArgs s y_s
            idGroups = orderArgs s $ caseJac k sc idmat
        pure . subExpsRes . concat $ zipWith (++) adjGroups (map concat idGroups)

      -- Later iterations: adjoints paired with the operator's Jacobian.
      laterIter :: ADM (Body SOACS)
      laterIter = buildBody_ $ do
        j <- reversedIdx
        j1 <- letSubExp "j1" =<< toExp (pe64 w - le64 i)
        y_s <- readAdjsAt j
        let args = map (`eIndex` [eSubExp j]) ys ++ map (`eIndex` [eSubExp j1]) xs
        lam_rs <- traverse (`eLambda` args) lams
        let adjGroups = orderArgs s $ subExpsRes y_s
            jacGroups = orderArgs s $ caseJac k sc $ transpose lam_rs
        pure . concat $ zipWith (++) adjGroups (map concat jacGroups)

  mkLambda [par_i] $
    fmap varsRes . letTupExp "x"
      =<< eIf (toExp $ le64 i .==. 0) firstIter laterIter
  where
    -- Trim a Jacobian (or identity) matrix according to the special case:
    --   * 'Nothing': unchanged.
    --   * 'MatrixMul': keep the leading k×k block.
    --   * 'ZeroQuadrant': split the rows into chunks of k and keep, from the
    --     q-th chunk, the columns [q*k, q*k + k).
    caseJac :: Int -> Maybe SpecialCase -> [[a]] -> [[a]]
    caseJac _ Nothing jac = jac
    caseJac k (Just ZeroQuadrant) jac =
      concat
        [ map (take k . drop (q * k)) rows
          | (q, rows) <- zip [0 .. d `div` k] (chunk k jac)
        ]
    caseJac k (Just MatrixMul) jac = map (take k) (take k jac)
-- | Compute @a2 `vectorAdd` (b · a1)@, where the product is 'matrixVecMul'
-- with element type @pt@.
--
-- In the 'MatrixMul' special case, @a1@ is cut into chunks of
-- @specialSubSize s@ entries.  @b@ is applied to each chunk separately and
-- the partial products are concatenated before the addition.
linFunT0 :: [PrimExp VName] -> [PrimExp VName] -> [[PrimExp VName]] -> Special -> PrimType -> [PrimExp VName]
linFunT0 a1 a2 b s pt = a2 `vectorAdd` bTimesA1
  where
    bTimesA1 = case specialCase s of
      Just MatrixMul ->
        [e | piece <- chunk (specialSubSize s) a1, e <- matrixVecMul b piece pt]
      _ -> matrixVecMul b a1 pt
-- | Build the 'Scan' (operator and neutral element) over elements of the
-- form @(a, b)@.
--
-- * @a@ has @fst (specialParams s)@ entries.
-- * @b@ has @snd (specialParams s)@ entries, viewed as a matrix with rows
--   of length @snd (specialNeutral s)@.
--
-- Combining @(a1, b1)@ with @(a2, b2)@ yields
-- @(linFunT0 a1 a2 b2 s pt, b2 · b1)@, with the matrix product flattened.
-- The neutral element is @fst (specialNeutral s)@ zeros followed by the
-- flattened identity matrix of size @snd (specialNeutral s)@.
--
-- Every lambda parameter has type @rowType t@, while the identity entries
-- use @Prim (elemType t)@.  NOTE(review): presumably these coincide for
-- every @t@ passed here — confirm.
mkScanLinFunO :: Type -> Special -> ADM (Scan SOACS)
mkScanLinFunO t s = do
  let pt = elemType t
      neutralShape@(_, n) = specialNeutral s
  neu_elm <- mkNeutral neutralShape
  (a1s, b1s, a2s, b2s) <- mkParams $ specialParams s
  let pet = primExpFromSubExp pt . Var
      rowParam v = Param mempty v (rowType t)
  lam <- mkLambda (map rowParam (a1s ++ b1s ++ a2s ++ b2s)) . fmap subExpsRes $ do
    -- Convert each parameter group on its own rather than through an
    -- incomplete list pattern, which is partial and warned about by GHC.
    let a1s' = map pet a1s
        a2s' = map pet a2s
        b1sm = chunk n (map pet b1s)
        b2sm = chunk n (map pet b2s)
        t0 = linFunT0 a1s' a2s' b2sm s pt
        t1 = concat $ matrixMul b2sm b1sm pt
    traverse (letSubExp "r" <=< toExp) $ t0 ++ t1
  pure $ Scan lam neu_elm
  where
    -- Neutral element: @a@ zeros, then a flattened @b × b@ identity matrix.
    mkNeutral :: (Int, Int) -> ADM [SubExp]
    mkNeutral (a, b) = do
      zeros <- replicateM a $ letSubExp "zeros" $ zeroExp $ rowType t
      idmat <- identityM b $ Prim $ elemType t
      pure $ zeros ++ concat idmat

    -- Fresh names for both operands: @a@ vector entries and @b@ matrix
    -- entries per operand, generated in the order a1, b1, a2, b2.
    mkParams :: (Int, Int) -> ADM ([VName], [VName], [VName], [VName])
    mkParams (a, b) = do
      a1s <- replicateM a $ newVName "a1"
      b1s <- replicateM b $ newVName "b1"
      a2s <- replicateM a $ newVName "a2"
      b2s <- replicateM b $ newVName "b2"
      pure (a1s, b1s, a2s, b2s)
-- | Build the final map of the reverse-mode pass of a scan.
--
-- The generated map runs over positions @pos@ in @[0, w)@ together with the
-- rows of @xs@.  For each position it reads the adjoints @ds@ back to front
-- (element @w - (pos + 1)@ of every array in @ds@):
--
-- * at @pos == 0@ those adjoints are the result as-is;
-- * otherwise they are pushed through the operator's adjoint lambda with
--   respect to its second operand ('mkScanAdjointLam' with 'WrtSecond'),
--   applied to @ys[pos - 1]@ followed by the current elements of @xs@.
--
-- Returns the names bound to the map's results.
mkScanFinalMap :: VjpOps -> SubExp -> Lambda SOACS -> [VName] -> [VName] -> [VName] -> ADM [VName]
mkScanFinalMap ops w scan_lam xs ys ds = do
  let elem_tps = lambdaReturnType scan_lam
  pos_param <- newParam "i" $ Prim int64
  let pos = le64 $ paramName pos_param
  x_params <- zipWithM (\x tp -> newParam (baseString x ++ "_par_x") tp) xs elem_tps

  map_lam <- mkLambda (pos_param : x_params) $ do
    -- Adjoints are consumed in reverse order.
    rev_pos <- letSubExp "j" =<< toExp (pe64 w - (pos + 1))
    rev_adjs <- forM ds $ \adj ->
      letExp (baseString adj ++ "_dj") =<< eIndex adj [eSubExp rev_pos]
    let adj_ses = map Var rev_adjs
        -- Branch for pos > 0: apply the adjoint of the operator (w.r.t. its
        -- second operand) to the previous ys element and the current x row.
        propagated = buildBody_ $ do
          adj_lam <- mkScanAdjointLam ops scan_lam WrtSecond adj_ses
          prev <- letSubExp "im1" =<< toExp (pos - 1)
          prev_ys <- forM ys $ \y ->
            letSubExp (baseString y <> "_im1") =<< eIndex y [eSubExp prev]
          eLambda adj_lam . map eSubExp $
            prev_ys ++ [Var (paramName p) | p <- x_params]
    contrib <- eIf (toExp $ pos .==. 0) (resultBodyM adj_ses) propagated
    varsRes <$> letTupExp "scan_contribs" contrib

  iota <- letExp "iota" . BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letTupExp "scan_contribs" . Op . Screma w (iota : xs) $ mapSOAC map_lam
-- | Structure that 'cases' can detect in the (block-diagonal) Jacobian
-- matrix of a scan operator.
data SpecialCase
  = -- | Block-diagonal matrix whose diagonal blocks are not all equal.
    ZeroQuadrant
  | -- | Block-diagonal matrix whose diagonal blocks are all equal.
    MatrixMul
  deriving (Show)
-- | Parameters of the specialised scan algorithm, as chosen by 'cases'.
--
-- NOTE(review): the precise meaning of each size is defined by the code
-- that consumes this record, which is not part of this view; the field
-- descriptions below are inferred from the field names — confirm.
data Special = Special
  { -- | Size pair describing the special neutral element.
    specialNeutral :: (Int, Int),
    -- | Size pair describing the special lambda parameters.
    specialParams :: (Int, Int),
    -- | Number of scans to perform ('cases' uses @1@ or @d `div` k@).
    specialScans :: Int,
    -- | Side length of one square sub-matrix.
    specialSubSize :: Int,
    -- | Detected structural special case, if any.
    specialCase :: Maybe SpecialCase
  }
  deriving (Show)
-- | How to differentiate a scan; computed by 'identifyCase'.
data ScanAlgo
  = -- | Generic algorithm specialised by the given 'Special' parameters
    -- (presumably the IFL'23 algorithm the name refers to — confirm).
    GenericIFL23 Special
  | -- | Fallback used when 'cases' finds no exploitable structure
    -- (presumably a parallel-prefix AD method — confirm).
    GenericPPAD
  deriving (Show)
-- | Find the finest block size of a block-diagonal matrix.
--
-- @subMats d mat zero@ inspects the first @d@ rows and columns of @mat@.
-- It returns @Just m@ for the smallest divisor @m@ of @d@ with
-- @m <= d `div` 2@ such that every entry outside the @m x m@ diagonal
-- blocks equals @zero@. It returns 'Nothing' when no such divisor exists,
-- in particular whenever @d <= 1@.
--
-- When there is no candidate divisor, @zero@ is never evaluated. Callers may
-- therefore pass a lazily computed value that is only valid for @d > 1@.
--
-- The element type is generalised from @Exp SOACS@ to any 'Eq' type.
-- Existing callers are unaffected.
subMats :: (Eq e) => Int -> [[e]] -> e -> Maybe Int
subMats d mat zero =
  case [m | m <- candidates, all (blockDiagonalRow m) (zip mat [0 .. d - 1])] of
    m : _ -> Just m
    [] -> Nothing
  where
    -- Proper divisors of d, smallest first, so the first hit is the finest
    -- block size.
    candidates = filter (\x -> d `mod` x == 0) [1 .. d `div` 2]
    -- Row i fits block size m when every entry outside row i's own m-block
    -- equals zero.
    blockDiagonalRow m (row, i) =
      and [v == zero || i `div` m == j `div` m | (v, j) <- zip row [0 .. d - 1]]
-- | Choose a scan algorithm from the @d x d@ matrix @mat@, whose entries
-- have element type @t@, as produced by 'identifyCase'.
--
-- * If 'subMats' finds a block size @k@ (the matrix is block diagonal with
--   @k x k@ blocks), the 'GenericIFL23' algorithm is used:
--
--     * with 'MatrixMul' and a single scan when every diagonal block is
--       equal (compared with '==');
--     * with 'ZeroQuadrant' and @d `div` k@ scans otherwise.
--
-- * With no block structure, a single-result operator (@d == 1@) still
--   uses 'GenericIFL23', without a special case. Anything else falls back
--   to 'GenericPPAD'.
cases :: Int -> Type -> [[Exp SOACS]] -> ScanAlgo
cases d t mat = case subMats d mat $ zeroExp t of
  Just k ->
    let -- Diagonal block i: columns [i*k, i*k + k) of the i'th chunk of k rows.
        diagBlocks =
          zipWith (\i -> map $ take k . drop (i * k)) [0 .. d `div` k] $ chunk k mat
     in if allEqual diagBlocks
          then GenericIFL23 $ Special (d, k) (d, k * k) 1 k $ Just MatrixMul
          else GenericIFL23 $ Special (k, k) (k, k * k) (d `div` k) k $ Just ZeroQuadrant
  Nothing
    | d == 1 -> GenericIFL23 $ Special (d, d) (d, d * d) 1 d Nothing
    | otherwise -> GenericPPAD
  where
    -- Total replacement for the former @all (== head xs) (tail xs)@, which
    -- crashed on an empty list. An empty list is vacuously all-equal.
    allEqual (b : bs) = all (== b) bs
    allEqual [] = True
-- | Inspect a scan operator and decide which reverse-mode scan algorithm to
-- use.
--
-- For an operator @lam@ with @d@ results:
--
-- 1. One adjoint lambda with respect to the first operand
--    ('mkScanAdjointLam' with 'WrtFirst') is built per row of the @d x d@
--    identity matrix ('identityM').
-- 2. All of them are applied to two fresh parameter lists (@tmp1@ and
--    @tmp2@).
-- 3. Their results are interleaved with 'transpose' and the lambda is
--    simplified.
-- 4. The simplified results are chunked into a @d x d@ matrix of
--    expressions, which 'cases' classifies.
identifyCase :: VjpOps -> Lambda SOACS -> ADM ScanAlgo
identifyCase ops lam = do
  let res_tps = lambdaReturnType lam
      n = length res_tps
      -- Deliberately lazy: only evaluated if a consumer needs the element type.
      elem_tp = head res_tps
  unit_rows <- identityM n elem_tp
  adj_lams <- mapM (mkScanAdjointLam ops lam WrtFirst) unit_rows
  fst_params <- mapM (newParam "tmp1") res_tps
  snd_params <- mapM (newParam "tmp2") res_tps
  let all_params = fst_params ++ snd_params
  jac_lam <- mkLambda all_params $ do
    let arg_exps = map eParam all_params
    per_seed <- mapM (\l -> eLambda l arg_exps) adj_lams
    pure . concat $ transpose per_seed
  jac_simp <- simplifyLambda jac_lam
  let entries = map (BasicOp . SubExp . resSubExp) . bodyResult $ lambdaBody jac_simp
  pure . cases n elem_tp $ chunk n entries
-- | Right-to-left inclusive scan of the arrays @as@ (outer size @w@).
--
-- The inputs are read in reverse and scanned forward. The scan operator is
-- @scan@'s operator with its two operands swapped, and the neutral element
-- is unchanged. The scan result is then reversed again. Returns the names
-- bound to the re-reversed results.
scanRight :: [VName] -> SubExp -> Scan SOACS -> ADM [VName]
scanRight as w scan = do
  row_tps <- map rowType <$> mapM lookupType as
  let mkOperandParams suffix =
        zipWithM (\v tp -> newParam (baseString v <> suffix) tp) as row_tps
  lhs_params <- mkOperandParams "_par_a1"
  rhs_params <- mkOperandParams "_par_a2"
  -- The original operator with swapped operands: flipped a1 a2 = op a2 a1.
  flipped_op <- mkLambda (lhs_params <> rhs_params) $ do
    fresh_op <- renameLambda (scanLambda scan)
    eLambda fresh_op [toExp (paramName p) | p <- rhs_params <> lhs_params]
  iota <-
    letExp "iota" . BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  read_reversed <- reverseReader as
  scanned <-
    letTupExp "adj_ctrb_scan" . Op . Screma w [iota] $
      scanomapSOAC [Scan flipped_op (scanNeutral scan)] read_reversed
  unreverse <- reverseReader scanned
  letTupExp "reverse_scan_result" . Op . Screma w [iota] $ mapSOAC unreverse
  where
    -- Lambda over an index i that reads element (w - i - 1) of every array.
    reverseReader :: [VName] -> ADM (Lambda SOACS)
    reverseReader arrs = do
      i_param <- newParam "i" $ Prim int64
      let rev_i = pe64 w - le64 (paramName i_param) - 1
      mkLambda [i_param] $
        forM arrs $ \arr -> do
          elt <- letExp "ys_bar_rev" =<< eIndex arr [toExp rev_i]
          pure (varRes elt)
-- | Build the lifted operator used by the generic parallel-prefix
-- ('GenericPPAD') scan adjoint.
--
-- For every array in @as@ (using that array's row type) the lambda takes
-- the parameter groups @x1, a1, y1_h@ followed by @x2, a2, y2_h@ (@x2@ is
-- unused), and returns, in order:
--
--   1. @x1@, re-bound as fresh @"x1"@ copies;
--   2. the (renamed) scan operator applied to @a1 ++ a2@;
--   3. for each component, @'addLambda'@ (presumably addition) applied to
--      the corresponding result of the 'WrtFirst' lambda built by
--      'mkScanAdjointLam' (given @y2_h@, applied to @x1 ++ a1@) and @y1_h@.
mkPPADOpLifted :: VjpOps -> [VName] -> Scan SOACS -> ADM (Lambda SOACS)
mkPPADOpLifted ops as scan = do
  as_types <- mapM lookupType as
  let arg_type_row = map rowType as_types
      -- One fresh parameter per array in @as@, named after that array.
      mkParams suffix =
        zipWithM (\x -> newParam (baseString x ++ suffix)) as arg_type_row
  par_x1 <- mkParams "_par_x1"
  par_x2_unused <- mkParams "_par_x2_unused"
  par_a1 <- mkParams "_par_a1"
  par_a2 <- mkParams "_par_a2"
  par_y1_h <- mkParams "_par_y1_h"
  par_y2_h <- mkParams "_par_y2_h"
  add_lams <- mapM addLambda arg_type_row
  mkLambda
    (par_x1 ++ par_a1 ++ par_y1_h ++ par_x2_unused ++ par_a2 ++ par_y2_h)
    (op_lift par_x1 par_a1 par_y1_h par_a2 par_y2_h add_lams)
  where
    op_lift px1 pa1 py1 pa2 py2 adds = do
      op_bar_1 <-
        mkScanAdjointLam ops (scanLambda scan) WrtFirst (Var . paramName <$> py2)
      let op_bar_args = toExp . Var . paramName <$> px1 ++ pa1
      z_term <- map resSubExp <$> eLambda op_bar_1 op_bar_args
      let -- Combine one adjoint component with the matching y1_h value.
          -- Each addition lambda is built for a single row type, so only its
          -- first result is used; an empty result is a compiler bug.
          addOne (z_t, y_1, add) = do
            res <- eLambda add [toExp z_t, toExp y_1]
            case res of
              r : _ -> pure r
              [] -> error "mkPPADOpLifted: addition lambda returned no result"
          z = mapM addOne (zip3 z_term (Var . paramName <$> py1) adds)
      let x1 = subExpsRes <$> mapM (toSubExp "x1" . Var . paramName) px1
      op <- renameLambda $ scanLambda scan
      let a3 = eLambda op (toExp . paramName <$> pa1 ++ pa2)
      -- Statements are emitted in this order: x1 copies, operator, sums.
      concat <$> sequence [x1, a3, z]
-- | Shift each array in @as@ one position to the left, using the matching
-- neutral element from @e@ for the vacated last slot.
--
-- Element @i@ of each result is @arr[i + 1]@ when @i < w - 1@, and the
-- neutral element otherwise. All results are produced by a single @map@
-- over @iota w@ and returned in the same order as @as@.
asLiftPPAD :: [VName] -> SubExp -> [SubExp] -> ADM [VName]
asLiftPPAD as w e = do
  par_i <- newParam "i" $ Prim int64
  -- @w - 1@ depends neither on the element index nor on which array is
  -- being shifted, so compute it once outside the mapped lambda (which
  -- captures it as a free variable) rather than once per array per element.
  nm1 <- toSubExp "n_minus_one" $ pe64 w - 1
  lmb <- mkLambda [par_i] $ do
    let i = paramName par_i
    forM (zip as e) $ \(arr, arr_e) -> do
      a_lift <-
        letExp "a_lift"
          =<< eIf
            (pure $ BasicOp $ CmpOp (CmpSlt Int64) (Var i) nm1)
            ( buildBody_ $
                (\x -> [subExpRes x])
                  <$> (letSubExp "val" =<< eIndex arr [toExp $ le64 i + 1])
            )
            (buildBody_ $ pure [subExpRes arr_e])
      pure $ varRes a_lift
  iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letTupExp "as_lift" $ Op $ Screma w [iota] $ mapSOAC lmb
-- | Shift each array in @ys@ one position to the right, using the matching
-- neutral element from @e@ for the vacated first slot.
--
-- Element 0 of each result is the neutral element; for @i > 0@ element @i@
-- is @arr[i - 1]@. All results come from a single @map@ over @iota w@ and
-- are returned in the same order as @ys@.
ysRightPPAD :: [VName] -> SubExp -> [SubExp] -> ADM [VName]
ysRightPPAD ys w e = do
  par_i <- newParam "i" $ Prim int64
  let i = paramName par_i
  lmb <- mkLambda [par_i] $
    forM (zip ys e) $ \(src, neutral) -> do
      let is_first = BasicOp $ CmpOp (CmpEq int64) (Var i) (constant (0 :: Int64))
          take_neutral = pure [subExpRes neutral]
          take_prev = do
            prev <- letSubExp "val" =<< eIndex src [toExp $ le64 i - 1]
            pure [subExpRes prev]
      y_right <-
        letExp "y_right"
          =<< eIf (pure is_first) (buildBody_ take_neutral) (buildBody_ take_prev)
      pure $ varRes y_right
  iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letTupExp "ys_right" $ Op $ Screma w [iota] $ mapSOAC lmb
-- | Build the lambda used by the final map of the generic PPAD scan adjoint.
--
-- For every array in @as@ (using that array's row type) the lambda takes a
-- @y_right@ parameter, an @a@ parameter and an @r_adj@ parameter, grouped
-- as @y_right ++ a ++ r_adj@. It returns the result of the 'WrtSecond'
-- lambda built by 'mkScanAdjointLam' (given the @r_adj@ values), applied to
-- @y_right ++ a@.
finalMapPPAD :: VjpOps -> [VName] -> Scan SOACS -> ADM (Lambda SOACS)
finalMapPPAD ops as scan = do
  row_ts <- map rowType <$> mapM lookupType as
  let freshParams suffix =
        sequence [newParam (baseString a ++ suffix) t | (a, t) <- zip as row_ts]
      names = map paramName
  par_y_right <- freshParams "_par_y_right"
  par_a <- freshParams "_par_a"
  par_r_adj <- freshParams "_par_r_adj"
  mkLambda (par_y_right ++ par_a ++ par_r_adj) $ do
    op_bar_2 <-
      mkScanAdjointLam ops (scanLambda scan) WrtSecond (map Var $ names par_r_adj)
    eLambda op_bar_2 $ map (toExp . Var) $ names (par_y_right ++ par_a)
-- | Reverse-mode rule for a single 'Scan' producing the arrays @ys@ from the
-- input arrays @as@, all of outer size @w@.
--
-- The scan operator is classified with 'identifyCase', and the resulting
-- 'ScanAlgo' selects how the adjoint contributions are built:
--
--   * 'GenericPPAD': operands are lifted ('asLiftPPAD'), a lifted operator
--     ('mkPPADOpLifted') is run right-to-left with 'scanRight', and a final
--     map ('finalMapPPAD') produces the contributions.
--   * 'GenericIFL23': a fused map lambda ('mkScanFusedMapLam') feeds a batch
--     of freshly renamed linear-function scans over @iota w@, whose results
--     are post-processed by 'mkScanFinalMap'.
--
-- Each branch yields one contribution array per element of @as@. That array
-- is then added to the corresponding adjoint with 'updateAdj'.
diffScan :: VjpOps -> [VName] -> SubExp -> [VName] -> Scan SOACS -> ADM ()
diffScan ops ys w as scan = do
  algo <- identifyCase ops scan_op
  let d = length as
  ys_adj <- mapM lookupAdjVal ys
  as_ts <- mapM lookupType as
  as_contribs <-
    case algo of
      GenericPPAD -> viaPPAD d ys_adj as_ts
      GenericIFL23 sc -> viaIFL23 d ys_adj as_ts sc
  zipWithM_ updateAdj as as_contribs
  where
    scan_op = scanLambda scan
    scan_ne = scanNeutral scan

    -- Contributions via the 'GenericPPAD' construction. @d@ is the number
    -- of scanned arrays, @ys_adj@ the adjoints of the results and @as_ts@
    -- the types of the inputs.
    viaPPAD :: Int -> [VName] -> [Type] -> ADM [VName]
    viaPPAD d ys_adj as_ts = do
      as_lift <- asLiftPPAD as w scan_ne
      op_lft <- mkPPADOpLifted ops as scan
      -- One zero value of each input's row type, used as the neutral
      -- element of the third operand group.
      a_zero <- forM as_ts $ \t ->
        Var <$> letExp "rscan_zero" (zeroExp (rowType t))
      let lifted_ins = ys ++ as_lift ++ ys_adj
          lft_scan = Scan op_lft (scan_ne ++ scan_ne ++ a_zero)
      -- Keep only the third chunk of d results. Presumably this is the
      -- ys_adj component scanned from the right (TODO: confirm against
      -- scanRight).
      rs_adj <- (\res -> chunk d res !! 2) <$> scanRight lifted_ins w lft_scan
      ys_right <- ysRightPPAD ys w scan_ne
      final_lmb <- finalMapPPAD ops as scan
      letTupExp "as_bar" $
        Op $
          Screma w (ys_right ++ as ++ rs_adj) $
            mapSOAC final_lmb

    -- Contributions via the 'GenericIFL23' construction for the special
    -- operator shape @sc@.
    viaIFL23 :: Int -> [VName] -> [Type] -> Special -> ADM [VName]
    viaIFL23 d ys_adj as_ts sc = do
      fused_map <- mkScanFusedMapLam ops w scan_op as ys ys_adj sc d
      lin_scan <- mkScanLinFunO (head as_ts) sc
      scan_lams <- mkScans (specialScans sc) lin_scan
      indices <-
        letExp "iota" $
          BasicOp $
            Iota w (intConst Int64 0) (intConst Int64 1) Int64
      r_scan <-
        letTupExp "adj_ctrb_scan" $
          Op $
            Screma w [indices] $
              scanomapSOAC scan_lams fused_map
      mkScanFinalMap ops w scan_op as ys (splitScanRes sc r_scan d)

    -- @k@ copies of the scan @s@. Each copy gets a freshly renamed operator
    -- lambda and the same neutral elements.
    mkScans :: Int -> Scan SOACS -> ADM [Scan SOACS]
    mkScans k s =
      replicateM k $ do
        op' <- renameLambda (scanLambda s)
        pure (Scan op' (scanNeutral s))

    -- Regroup the scan results with 'orderArgs' and keep the first
    -- @d `div` specialScans sc@ arrays of every group.
    splitScanRes sc res d =
      let keep = d `div` specialScans sc
       in concatMap (take keep) (orderArgs sc res)
-- | Reverse-mode rule for a scan over arrays of arrays. Here @lam@ is the
-- element-wise operator, @ne@ holds the neutral elements, @as@ the inputs and
-- @ys@ the results, each of outer size @w@.
--
-- The forward computation is re-expressed as statements collected with
-- 'collectStms_':
--
--   1. every input in @as@ is rearranged with the permutation
--      @[1, 0, 2, .., rank-1]@, swapping its two outermost dimensions;
--   2. a map over the new outer dimension @n@ runs, for each row, an inner
--      scan of @lam@ over @w@ elements, with the corresponding rows of @ne@
--      as neutral elements;
--   3. each result is rearranged back and bound to the original @ys@ name
--      under @aux@.
--
-- The collected statements are then differentiated with 'vjpStm', folded
-- right so that @m@ is the innermost continuation.
--
-- NOTE(review): assumes the inputs have rank >= 2; for rank 1 the
-- permutation would be @[1, 0]@, which is invalid. TODO confirm.
diffScanVec ::
  VjpOps ->
  [VName] ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  [SubExp] ->
  [VName] ->
  ADM () ->
  ADM ()
diffScanVec ops ys aux w lam ne as m = do
  stmts <- collectStms_ $ do
    rank <- arrayRank <$> lookupType (head as)
    -- Swap the two outermost dimensions and leave the rest in place.
    let rear = [1, 0] ++ drop 2 [0 .. rank - 1]

    transp_as <-
      forM as $ \a ->
        letExp (baseString a ++ "_transp") $ BasicOp $ Rearrange a rear

    ts <- traverse lookupType transp_as
    let n = arraysSize 0 ts

    as_par <- traverse (newParam "as_par" . rowType) ts
    ne_par <- traverse (newParam "ne_par") $ lambdaReturnType lam

    scan_form <- scanSOAC [Scan lam (map (Var . paramName) ne_par)]

    map_lam <-
      mkLambda (as_par ++ ne_par) . fmap varsRes . letTupExp "map_res" . Op $
        Screma w (map paramName as_par) scan_form

    -- NOTE(review): 'subExpVars' keeps only variable operands. This assumes
    -- every element of @ne@ is a variable (an array); a constant would
    -- misalign the map inputs with @ne_par@. TODO confirm callers
    -- guarantee this.
    transp_ys <-
      letTupExp "trans_ys" . Op $
        Screma n (transp_as ++ subExpVars ne) (mapSOAC map_lam)

    -- Bind each original result name to the back-transposed array. Only
    -- the bindings matter, so use forM_ rather than collecting a [()].
    forM_ (zip ys transp_ys) $ \(y, x) ->
      auxing aux $ letBindNames [y] $ BasicOp $ Rearrange x rear

  foldr (vjpStm ops) m stmts
-- | Reverse-mode rule for a scan over the single array @as@ (length @n@)
-- producing @ys@, with operator @lam'@ and neutral element @ne@.
--
-- Let @ys_bar@ be the adjoint of @ys@. The contribution computed here is
--
-- > contrb[j] = ys_bar[j] `lam` ys_bar[j+1] `lam` ... `lam` ys_bar[n-1]
--
-- It is obtained by reading @ys_bar@ back to front, running an inclusive
-- scan with @lam@, and reading that result back to front again. This equals
-- the adjoint of the inputs when @lam@ is addition, as the function name
-- suggests. The contribution is added to the adjoint of @as@.
diffScanAdd :: VjpOps -> VName -> SubExp -> Lambda SOACS -> SubExp -> VName -> ADM ()
diffScanAdd _ops ys n lam' ne as = do
  lam <- renameLambda lam'
  ys_bar <- lookupAdjVal ys
  read_ys_bar_rev <- reversedRead ys_bar
  indices <- letExp "iota" $ BasicOp $ Iota n zero64 one64 Int64
  -- Inclusive scan over ys_bar taken back to front, i.e. its suffix
  -- reductions, stored in reverse order.
  suffix_rev <-
    letExp "res_rev" . Op . Screma n [indices] $
      scanomapSOAC [Scan lam [ne]] read_ys_bar_rev
  read_suffix_rev <- reversedRead suffix_rev
  contrb <-
    letExp "contrb" . Op . Screma n [indices] $
      mapSOAC read_suffix_rev
  updateAdj as contrb
  where
    zero64 = intConst Int64 0
    one64 = intConst Int64 1

    -- A one-parameter lambda mapping index i to arr[n - i - 1], so that
    -- mapping it over iota n reads arr back to front.
    reversedRead :: VName -> ADM (Lambda SOACS)
    reversedRead arr = do
      i <- newParam "i" (Prim int64)
      mkLambda [i] $ do
        let rev_idx = pe64 n - le64 (paramName i) - 1
        v <- letExp "ys_bar_rev" =<< eIndex arr [toExp rev_idx]
        pure [varRes v]