module Futhark.AD.Rev.Scan (diffScan, diffScanVec, diffScanAdd) where

import Control.Monad
import Data.List (transpose)
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify (simplifyLambda)
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (chunk)

-- | Selects which operand of a binary scan operator a derivative is taken
-- with respect to.
data FirstOrSecond
  = -- | The first (left) operand, i.e. the leading half of the operator's
    -- parameters.
    WrtFirst
  | -- | The second (right) operand, i.e. the trailing half of the operator's
    -- parameters.
    WrtSecond

-- | Build an @n × n@ identity matrix whose entries have type @t@.
--
-- Every entry is bound to a fresh variable (base name @"id"@) with
-- 'letSubExp'. Bindings are emitted row by row. Diagonal entries are
-- @oneExp t@ and all other entries are @zeroExp t@.
--
-- When @n == 0@ the result is @[]@ and @t@ is never inspected.
identityM :: Int -> Type -> ADM [[SubExp]]
identityM n t =
  traverse
    (traverse (letSubExp "id"))
    [[if i == j then oneExp t else zeroExp t | i <- [1 .. n]] | j <- [1 .. n]]

-- | Symbolic matrix product @m1 * m2@ over 'PrimExp's.
--
-- Both matrices are given as lists of rows. Entry @(r, c)@ of the result
-- is the dot product of row @r@ of @m1@ with column @c@ of @m2@; the
-- columns are obtained by transposing @m2@. Each dot product is a left fold
-- of @~+~@ over the element-wise @~*~@ products. The fold starts from the
-- blank value of @t@, which is used here as the additive zero.
--
-- Mismatched inner dimensions are not reported: 'zipWith' silently
-- truncates to the shorter operand.
matrixMul :: [[PrimExp VName]] -> [[PrimExp VName]] -> PrimType -> [[PrimExp VName]]
matrixMul m1 m2 t =
  let zero = primExpFromSubExp t $ Constant $ blankPrimValue t
      dot row col = foldl (~+~) zero $ zipWith (~*~) row col
   in [[dot r q | q <- transpose m2] | r <- m1]

-- | Symbolic matrix-vector product @m * v@ over 'PrimExp's.
--
-- @m@ is a list of rows. Each result element is the dot product of @v@ with
-- one row of @m@. The dot product is a left fold of @~+~@ over
-- @zipWith (~*~) v row@, starting from the blank value of @t@ (used here as
-- the additive zero).
--
-- A length mismatch between @v@ and a row is not reported: 'zipWith'
-- truncates to the shorter one.
matrixVecMul :: [[PrimExp VName]] -> [PrimExp VName] -> PrimType -> [PrimExp VName]
matrixVecMul m v t =
  let zero = primExpFromSubExp t $ Constant $ blankPrimValue t
      dot row = foldl (~+~) zero $ zipWith (~*~) v row
   in [dot r | r <- m]

-- | Element-wise symbolic sum of two vectors of 'PrimExp's.
--
-- Like 'zipWith', the result is truncated to the length of the shorter
-- argument.
vectorAdd :: [PrimExp VName] -> [PrimExp VName] -> [PrimExp VName]
vectorAdd = zipWith (~+~)

-- | Split @lst@ into consecutive groups of size
-- @length lst `div` specialScans s@, using 'chunk'.
--
-- NOTE(review): this assumes @specialScans s > 0@, since @div@ would
-- otherwise fail. It also assumes @length lst@ is a multiple of
-- @specialScans s@; if it is not, the grouping is whatever 'chunk' does
-- with the remainder.
orderArgs :: Special -> [a] -> [[a]]
orderArgs s lst = chunk groupSize lst
  where
    groupSize = length lst `div` specialScans s

-- | Build the adjoint lambda of a binary scan operator with respect to one
-- of its operands. It computes @d(x op y)/dx@ ('WrtFirst') or
-- @d(x op y)/dy@ ('WrtSecond'), seeded with the result adjoints @adjs@.
--
-- Let @len@ be the number of results of @lam0@. The first @len@ parameters
-- of the operator are taken as @x@ and the remaining ones as @y@. The
-- operator is renamed before differentiation, so the generated lambda does
-- not share names with @lam0@.
mkScanAdjointLam :: VjpOps -> Lambda SOACS -> FirstOrSecond -> [SubExp] -> ADM (Lambda SOACS)
mkScanAdjointLam ops lam0 which adjs = do
  let len = length $ lambdaReturnType lam0
  lam <- renameLambda lam0
  -- Parameters to differentiate with respect to: leading or trailing half.
  let p2diff =
        case which of
          WrtFirst -> take len $ lambdaParams lam
          WrtSecond -> drop len $ lambdaParams lam
  vjpLambda ops (fmap AdjVal adjs) (map paramName p2diff) lam

-- | Build the map lambda that is fused in front of the reverse scan.
--
-- The lambda takes one @i64@ parameter @i@ and reads position
-- @j = w - 1 - i@. If @i@ ranges over @iota w@ (presumably; confirm at the
-- call site), the arrays are therefore visited back to front. Conceptually
-- it returns
--
-- > if i == 0
-- >   then (ys_adj[j], I)
-- >   else (ys_adj[j], d(ys[j] op xs[j+1]) / d ys[j])
--
-- Here @op@ is the scan operator and @I@ is the identity matrix.
--
-- The Jacobian is computed by applying the adjoint of @op@ with respect to
-- its first operand ('mkScanAdjointLam' with 'WrtFirst') once per row of
-- the identity matrix.
--
-- Both the adjoints and the Jacobian are then split per sub-scan with
-- 'orderArgs'. For special cases the Jacobian is also trimmed with
-- @caseJac@. Finally the pieces are interleaved as
-- @adj_0 ++ jac_0 ++ adj_1 ++ jac_1 ++ ...@.
mkScanFusedMapLam ::
  VjpOps -> -- (ops) helper functions
  SubExp -> -- (w) ~length of arrays e.g. xs
  Lambda SOACS -> -- (scn_lam) the scan to be differentiated ('scan' turned into a lambda)
  [VName] -> -- (xs) input of the scan (actually as)
  [VName] -> -- (ys) output of the scan
  [VName] -> -- (ys_adj) adjoint of ys
  Special -> -- (s) information about which special case we're working with for the scan derivative
  Int -> -- (d) dimension of the input (number of elements in the input tuple)
  ADM (Lambda SOACS) -- output: the fused map lambda
mkScanFusedMapLam ops w scn_lam xs ys ys_adj s d = do
  let sc = specialCase s
      k = specialSubSize s
  ys_ts <- traverse lookupType ys
  -- 'head' is only forced when ys is non-empty: identityM 0 never inspects
  -- its element type.
  idmat <- identityM (length ys) $ rowType $ head ys_ts
  -- One adjoint lambda per identity row; together they yield the Jacobian.
  lams <- traverse (mkScanAdjointLam ops scn_lam WrtFirst) idmat
  par_i <- newParam "i" $ Prim int64
  let i = paramName par_i

      -- Bind j = w - (i + 1), the array position this invocation handles.
      bindJ :: ADM SubExp
      bindJ = letSubExp "j" =<< toExp (pe64 w - (le64 i + 1))

      -- Read ys_adj[j] for every adjoint array.
      readAdjs :: SubExp -> ADM [SubExp]
      readAdjs j =
        forM ys_adj $ \y_ ->
          letSubExp (baseString y_ ++ "_j") =<< eIndex y_ [eSubExp j]

  mkLambda [par_i] $
    fmap varsRes . letTupExp "x"
      =<< eIf
        (toExp $ le64 i .==. 0)
        ( buildBody_ $ do
            -- Last element: the Jacobian is the identity.
            j <- bindJ
            y_s <- readAdjs j
            let zso = orderArgs s y_s
                ido = orderArgs s $ caseJac k sc idmat
            pure $ subExpsRes $ concat $ zipWith (++) zso $ fmap concat ido
        )
        ( buildBody_ $ do
            j <- bindJ
            -- j1 = w - i = j + 1, the next input element.
            j1 <- letSubExp "j1" =<< toExp (pe64 w - le64 i)
            y_s <- readAdjs j

            -- Operator arguments: (ys[j], xs[j+1]).
            let args =
                  map (`eIndex` [eSubExp j]) ys ++ map (`eIndex` [eSubExp j1]) xs
            lam_rs <- traverse (`eLambda` args) lams

            let yso = orderArgs s $ subExpsRes y_s
                jaco = orderArgs s $ caseJac k sc $ transpose lam_rs

            pure $ concat $ zipWith (++) yso $ fmap concat jaco
        )
  where
    -- Restrict a Jacobian (a list of rows) to the part a special case needs:
    --   * no special case: unchanged;
    --   * ZeroQuadrant: for the i-th group of k rows, keep columns
    --     [i*k, i*k + k), i.e. keep only the k×k diagonal blocks;
    --   * MatrixMul: keep only the top-left k×k block.
    caseJac :: Int -> Maybe SpecialCase -> [[a]] -> [[a]]
    caseJac _ Nothing jac = jac
    caseJac k (Just ZeroQuadrant) jac =
      concat $
        zipWith (\blk -> map (take k . drop (blk * k))) [0 .. d `div` k] $
          chunk k jac
    caseJac k (Just MatrixMul) jac =
      take k <$> take k jac

-- | The additive part of the linear-function scan operator:
-- @a1 a2 b -> a2 + b * a1@.
--
-- @b@ is a matrix given as a list of rows, and @pt@ is the element type
-- used for the zero in the matrix-vector products.
--
-- In the 'MatrixMul' special case, @a1@ is split into chunks of
-- @specialSubSize s@ elements. Each chunk is multiplied by @b@ separately,
-- and the products are concatenated before @a2@ is added.
linFunT0 ::
  [PrimExp VName] ->
  [PrimExp VName] ->
  [[PrimExp VName]] ->
  Special ->
  PrimType ->
  [PrimExp VName]
linFunT0 a1 a2 b s pt = a2 `vectorAdd` ba1
  where
    -- b * a1, applied per sub-vector in the MatrixMul case.
    ba1 = case specialCase s of
      Just MatrixMul ->
        concatMap (\v -> matrixVecMul b v pt) $ chunk (specialSubSize s) a1
      _ -> matrixVecMul b a1 pt

-- \(a1, b1) (a2, b2) -> (a2 + b2 * a1, b2 * b1)
mkScanLinFunO :: Type -> Special -> ADM (Scan SOACS) -- a is an instance of y_bar, b is a Jacobian (a 'c' in the 2023 paper)
mkScanLinFunO :: TypeBase Shape NoUniqueness -> Special -> ADM (Scan SOACS)
mkScanLinFunO TypeBase Shape NoUniqueness
t Special
s = do
  let pt :: PrimType
pt = TypeBase Shape NoUniqueness -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType TypeBase Shape NoUniqueness
t
  [SubExp]
neu_elm <- (Int, Int) -> ADM [SubExp]
mkNeutral ((Int, Int) -> ADM [SubExp]) -> (Int, Int) -> ADM [SubExp]
forall a b. (a -> b) -> a -> b
$ Special -> (Int, Int)
specialNeutral Special
s
  let (Int
as, Int
bs) = Special -> (Int, Int)
specialParams Special
s -- input size, Jacobian element count
  ([VName]
a1s, [VName]
b1s, [VName]
a2s, [VName]
b2s) <- (Int, Int) -> ADM ([VName], [VName], [VName], [VName])
forall {m :: * -> *}.
MonadFreshNames m =>
(Int, Int) -> m ([VName], [VName], [VName], [VName])
mkParams (Int
as, Int
bs) -- create sufficient free variables to bind every element of the vectors / matrices
  let pet :: VName -> PrimExp VName
pet = PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
pt (SubExp -> PrimExp VName)
-> (VName -> SubExp) -> VName -> PrimExp VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var -- manifest variable names as expressions
  let (Int
_, Int
n) = Special -> (Int, Int)
specialNeutral Special
s -- output size (one side of the Jacobian)
  Lambda SOACS
lam <- [LParam (Rep ADM)] -> ADM Result -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda ((VName -> Param (TypeBase Shape NoUniqueness))
-> [VName] -> [Param (TypeBase Shape NoUniqueness)]
forall a b. (a -> b) -> [a] -> [b]
map (\VName
v -> Attrs
-> VName
-> TypeBase Shape NoUniqueness
-> Param (TypeBase Shape NoUniqueness)
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
forall a. Monoid a => a
mempty VName
v (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
forall u. TypeBase Shape u -> TypeBase Shape u
rowType TypeBase Shape NoUniqueness
t)) ([VName]
a1s [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
b1s [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
a2s [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
b2s)) (ADM Result -> ADM (Lambda SOACS))
-> (ADM [SubExp] -> ADM Result)
-> ADM [SubExp]
-> ADM (Lambda SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([SubExp] -> Result) -> ADM [SubExp] -> ADM Result
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [SubExp] -> Result
subExpsRes (ADM [SubExp] -> ADM (Lambda SOACS))
-> ADM [SubExp] -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ do
    let [[PrimExp VName]
a1s', [PrimExp VName]
b1s', [PrimExp VName]
a2s', [PrimExp VName]
b2s'] = (([VName] -> [PrimExp VName]) -> [[VName]] -> [[PrimExp VName]]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (([VName] -> [PrimExp VName]) -> [[VName]] -> [[PrimExp VName]])
-> ((VName -> PrimExp VName) -> [VName] -> [PrimExp VName])
-> (VName -> PrimExp VName)
-> [[VName]]
-> [[PrimExp VName]]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> PrimExp VName) -> [VName] -> [PrimExp VName]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap) VName -> PrimExp VName
pet [[VName]
a1s, [VName]
b1s, [VName]
a2s, [VName]
b2s]
    let ([[PrimExp VName]]
b1sm, [[PrimExp VName]]
b2sm) = (Int -> [PrimExp VName] -> [[PrimExp VName]]
forall a. Int -> [a] -> [[a]]
chunk Int
n [PrimExp VName]
b1s', Int -> [PrimExp VName] -> [[PrimExp VName]]
forall a. Int -> [a] -> [[a]]
chunk Int
n [PrimExp VName]
b2s')

    let t0 :: [PrimExp VName]
t0 = [PrimExp VName]
-> [PrimExp VName]
-> [[PrimExp VName]]
-> Special
-> PrimType
-> [PrimExp VName]
linFunT0 [PrimExp VName]
a1s' [PrimExp VName]
a2s' [[PrimExp VName]]
b2sm Special
s PrimType
pt
    let t1 :: [PrimExp VName]
t1 = [[PrimExp VName]] -> [PrimExp VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([[PrimExp VName]] -> [PrimExp VName])
-> [[PrimExp VName]] -> [PrimExp VName]
forall a b. (a -> b) -> a -> b
$ [[PrimExp VName]]
-> [[PrimExp VName]] -> PrimType -> [[PrimExp VName]]
matrixMul [[PrimExp VName]]
b2sm [[PrimExp VName]]
b1sm PrimType
pt
    (PrimExp VName -> ADM SubExp) -> [PrimExp VName] -> ADM [SubExp]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse ([Char] -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"r" (Exp SOACS -> ADM SubExp)
-> (PrimExp VName -> ADM (Exp SOACS))
-> PrimExp VName
-> ADM SubExp
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< PrimExp VName -> ADM (Exp (Rep ADM))
PrimExp VName -> ADM (Exp SOACS)
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
PrimExp VName -> m (Exp (Rep m))
toExp) ([PrimExp VName] -> ADM [SubExp])
-> [PrimExp VName] -> ADM [SubExp]
forall a b. (a -> b) -> a -> b
$ [PrimExp VName]
t0 [PrimExp VName] -> [PrimExp VName] -> [PrimExp VName]
forall a. [a] -> [a] -> [a]
++ [PrimExp VName]
t1

  Scan SOACS -> ADM (Scan SOACS)
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Scan SOACS -> ADM (Scan SOACS)) -> Scan SOACS -> ADM (Scan SOACS)
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [SubExp] -> Scan SOACS
forall rep. Lambda rep -> [SubExp] -> Scan rep
Scan Lambda SOACS
lam [SubExp]
neu_elm
  where
    mkNeutral :: (Int, Int) -> ADM [SubExp]
mkNeutral (Int
a, Int
b) = do
      [SubExp]
zeros <- Int -> ADM SubExp -> ADM [SubExp]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
a (ADM SubExp -> ADM [SubExp]) -> ADM SubExp -> ADM [SubExp]
forall a b. (a -> b) -> a -> b
$ [Char] -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"zeros" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ TypeBase Shape NoUniqueness -> Exp (Rep ADM)
forall rep. TypeBase Shape NoUniqueness -> Exp rep
zeroExp (TypeBase Shape NoUniqueness -> Exp (Rep ADM))
-> TypeBase Shape NoUniqueness -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
forall u. TypeBase Shape u -> TypeBase Shape u
rowType TypeBase Shape NoUniqueness
t
      [[SubExp]]
idmat <- Int -> TypeBase Shape NoUniqueness -> ADM [[SubExp]]
identityM Int
b (TypeBase Shape NoUniqueness -> ADM [[SubExp]])
-> TypeBase Shape NoUniqueness -> ADM [[SubExp]]
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase Shape NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim (PrimType -> TypeBase Shape NoUniqueness)
-> PrimType -> TypeBase Shape NoUniqueness
forall a b. (a -> b) -> a -> b
$ TypeBase Shape NoUniqueness -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType TypeBase Shape NoUniqueness
t
      [SubExp] -> ADM [SubExp]
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([SubExp] -> ADM [SubExp]) -> [SubExp] -> ADM [SubExp]
forall a b. (a -> b) -> a -> b
$ [SubExp]
zeros [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ [[SubExp]] -> [SubExp]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[SubExp]]
idmat

    mkParams :: (Int, Int) -> m ([VName], [VName], [VName], [VName])
mkParams (Int
a, Int
b) = do
      [VName]
a1s <- Int -> m VName -> m [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
a (m VName -> m [VName]) -> m VName -> m [VName]
forall a b. (a -> b) -> a -> b
$ [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"a1"
      [VName]
b1s <- Int -> m VName -> m [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
b (m VName -> m [VName]) -> m VName -> m [VName]
forall a b. (a -> b) -> a -> b
$ [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"b1"
      [VName]
a2s <- Int -> m VName -> m [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
a (m VName -> m [VName]) -> m VName -> m [VName]
forall a b. (a -> b) -> a -> b
$ [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"a2"
      [VName]
b2s <- Int -> m VName -> m [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
b (m VName -> m [VName]) -> m VName -> m [VName]
forall a b. (a -> b) -> a -> b
$ [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"b2"
      ([VName], [VName], [VName], [VName])
-> m ([VName], [VName], [VName], [VName])
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName]
a1s, [VName]
b1s, [VName]
a2s, [VName]
b2s)

-- | Build the final map of the scan's return sweep. It produces, for
-- every input array in @xs@, the array of adjoint contributions.
-- Conceptually:
--
-- > let xs_contribs =
-- >    map3 (\ i a r -> if i==0 then r else (df2dy (ys[i-1]) a) \bar{*} r)
-- >         (iota n) xs (reverse ds)
--
-- Element @i@ of the map reads every array in @ds@ at index
-- @w - (i + 1)@, i.e. @ds@ is traversed back to front. The first
-- element passes those values through unchanged. Every other element
-- applies the adjoint of @scan_lam@ with respect to its second operand,
-- seeded with the read values, to @ys[i-1]@ and the current @xs@
-- elements. The result is the list of names bound to the mapped arrays.
mkScanFinalMap :: VjpOps -> SubExp -> Lambda SOACS -> [VName] -> [VName] -> [VName] -> ADM [VName]
mkScanFinalMap ops w scan_lam xs ys ds = do
  idx_param <- newParam "i" (Prim int64)
  let idx = paramName idx_param
  x_params <- zipWithM (newParam . (++ "_par_x") . baseString) xs (lambdaReturnType scan_lam)

  map_lam <- mkLambda (idx_param : x_params) $ do
    -- Index into ds from the back: w - (i + 1).
    rev_idx <- letSubExp "j" =<< toExp (pe64 w - (le64 idx + 1))
    d_rev <- forM ds $ \d_arr ->
      letExp (baseString d_arr ++ "_dj") =<< eIndex d_arr [eSubExp rev_idx]

    let is_first = toExp $ le64 idx .==. 0
        -- The first element has no predecessor; pass the adjoint through.
        first_elem = resultBodyM (map Var d_rev)
        -- Otherwise, differentiate the operator wrt its second operand at
        -- (ys[i-1], x_i) and apply it.
        other_elem = buildBody_ $ do
          adj_lam <- mkScanAdjointLam ops scan_lam WrtSecond (map Var d_rev)
          prev <- letSubExp "im1" =<< toExp (le64 idx - 1)
          ys_prev <- forM ys $ \y ->
            letSubExp (baseString y <> "_im1") =<< eIndex y [eSubExp prev]
          eLambda adj_lam $ map eSubExp $ ys_prev ++ map (Var . paramName) x_params

    contribs <- letTupExp "scan_contribs" =<< eIf is_first first_elem other_elem
    pure (varsRes contribs)

  iota <- letExp "iota" . BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letTupExp "scan_contribs" $ Op $ Screma w (iota : xs) (mapSOAC map_lam)

-- | Special shapes of the scan operator's Jacobian that 'cases'
-- recognises and handles with a dedicated return-sweep scan.
data SpecialCase
  = -- | Block-diagonal Jacobian whose diagonal blocks are not all
    -- identical; one scan is performed per block.
    ZeroQuadrant
  | -- | Block-diagonal Jacobian whose diagonal blocks are all
    -- identical; a single scan suffices.
    MatrixMul
  deriving Show

-- | Metadata describing how 'cases' decided the return-sweep scan
-- should be performed. Below, @d@ is the number of results of the scan
-- operator and @k@ the side length of its Jacobian's diagonal blocks.
data Special = Special
  { -- | Dimensions @(a, b)@ of the special neutral element (not the
    -- element itself). Set to @(d, k)@ for 'MatrixMul', @(k, k)@ for
    -- 'ZeroQuadrant' and @(d, d)@ (with @d == 1@) for scalar operators.
    -- NOTE(review): presumably @a@ counts value components and @b@ is
    -- the side of the square Jacobian block; confirm at the use site.
    specialNeutral :: (Int, Int),
    -- | Parameter counts @(a, b)@ for the special lambda. @a@ matches
    -- the first component of 'specialNeutral', and @b@ is the flat size
    -- of one Jacobian block (@k * k@, or @d * d@ for scalar operators).
    specialParams :: (Int, Int),
    -- | How many scans to perform: @d `div` k@ for 'ZeroQuadrant', and
    -- 1 in every other case.
    specialScans :: Int,
    -- | Side length of the Jacobian's diagonal blocks: @k@ for the
    -- block-diagonal cases and @d@ (which is 1) for scalar operators.
    specialSubSize :: Int,
    -- | The special case that applies, or 'Nothing' for a scalar
    -- operator.
    specialCase :: Maybe SpecialCase
  }
  deriving Show

-- | Strategy for differentiating a scan. 'cases' picks one
-- heuristically by inspecting the Jacobian of the operator.
data ScanAlgo
  = -- | Build and compose the Jacobians explicitly, as described in
    -- *Reverse-Mode AD of Multi-Reduce and Scan in Futhark*. The
    -- 'Special' payload records how the scan is laid out.
    GenericIFL23 Special
  | -- | The approach of *Parallelism-preserving automatic
    -- differentiation for second-order array languages*.
    GenericPPAD
  deriving Show

-- | Find the smallest block size @m@ such that the @d x d@ matrix @mat@
-- is block diagonal with @m x m@ blocks. Candidates for @m@ are the
-- divisors of @d@ in @[1 .. d `div` 2]@. A matrix is accepted when every
-- entry outside the diagonal blocks is syntactically equal to @zero@.
-- Returns 'Nothing' if no candidate works, which is always the case
-- when @d < 2@.
subMats :: Int -> [[Exp SOACS]] -> Exp SOACS -> Maybe Int
subMats d mat zero =
  case [m | m <- candidates, blockDiagonal m] of
    [] -> Nothing
    k : _ -> Just k
  where
    candidates = [x | x <- [1 .. d `div` 2], d `mod` x == 0]
    blockDiagonal m = all (rowFits m) (zip mat [0 .. d - 1])
    -- Entry (i, j) may be non-zero only if i and j fall in the same block.
    rowFits m (row, i) =
      and [v == zero || i `div` m == j `div` m | (v, j) <- zip row [0 .. d - 1]]

-- | Choose a 'ScanAlgo' from the simplified Jacobian @mat@ of a scan
-- operator with @d@ results whose element type is @t@.
--
-- * If @mat@ is block diagonal with @k x k@ blocks (see 'subMats') and
--   every diagonal block is syntactically identical, use 'MatrixMul'
--   with a single scan.
-- * If it is block diagonal but the blocks differ, use 'ZeroQuadrant'
--   with @d `div` k@ scans.
-- * Otherwise, a scalar operator (@d == 1@) still uses 'GenericIFL23'
--   without a special case. Anything else falls back to 'GenericPPAD'.
cases :: Int -> Type -> [[Exp SOACS]] -> ScanAlgo
cases d t mat = case subMats d mat $ zeroExp t of
  Just k ->
    -- The i'th group of k rows, restricted to columns [i*k, i*k + k),
    -- is the i'th diagonal block.
    let diag_blocks = zipWith (\i -> map $ take k . drop (i * k)) [0 .. d `div` k] $ chunk k mat
     in if allSame diag_blocks
          then GenericIFL23 $ Special (d, k) (d, k * k) 1 k $ Just MatrixMul
          else GenericIFL23 $ Special (k, k) (k, k * k) (d `div` k) k $ Just ZeroQuadrant
  Nothing
    | d == 1 -> GenericIFL23 $ Special (d, d) (d, d * d) 1 d Nothing
    | otherwise -> GenericPPAD
  where
    -- Total replacement for @all (== head xs) (tail xs)@. The original
    -- crashed on an empty list; an empty list is vacuously "all the same".
    allSame :: (Eq a) => [a] -> Bool
    allSame [] = True
    allSame (x : xs) = all (== x) xs

-- | Build a temporary lambda computing the Jacobian of the scan operator
-- @lam@, simplify it, and classify the resulting Jacobian with 'cases'.
-- The temporary lambda is discarded; only the chosen 'ScanAlgo' is
-- returned.
--
-- The Jacobian is obtained by seeding the adjoint of @lam@ (with respect
-- to its first operand) with each row of an identity matrix. The
-- per-seed results are then transposed and flattened into @d@ rows.
--
-- NOTE(review): the identity matrix and the zero used by 'cases' are
-- built from the first result type only. Presumably all results are
-- assumed to share an element type; confirm for mixed-type operators.
identifyCase :: VjpOps -> Lambda SOACS -> ADM ScanAlgo
identifyCase ops lam = do
  let res_ts = lambdaReturnType lam
      n = length res_ts
      -- Lazy: never forced when n == 0.
      elem_t = head res_ts
  unit_vecs <- identityM n elem_t
  adj_lams <- mapM (mkScanAdjointLam ops lam WrtFirst) unit_vecs
  lhs <- mapM (newParam "tmp1") res_ts
  rhs <- mapM (newParam "tmp2") res_ts
  jac_lam <- mkLambda (lhs ++ rhs) $ do
    let operands = map eParam (lhs ++ rhs)
    rows <- mapM (\adj -> eLambda adj operands) adj_lams
    pure (concat (transpose rows))

  simp <- simplifyLambda jac_lam
  let entries = map (BasicOp . SubExp . resSubExp) . bodyResult . lambdaBody $ simp
  pure . cases n elem_t $ chunk n entries

scanRight :: [VName] -> SubExp -> Scan SOACS -> ADM [VName]
scanRight :: [VName] -> SubExp -> Scan SOACS -> ADM [VName]
scanRight [VName]
as SubExp
w Scan SOACS
scan = do
  [TypeBase Shape NoUniqueness]
as_types <- (VName -> ADM (TypeBase Shape NoUniqueness))
-> [VName] -> ADM [TypeBase Shape NoUniqueness]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> ADM (TypeBase Shape NoUniqueness)
forall rep (m :: * -> *).
HasScope rep m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType [VName]
as
  let arg_type_row :: [TypeBase Shape NoUniqueness]
arg_type_row = (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
forall u. TypeBase Shape u -> TypeBase Shape u
rowType [TypeBase Shape NoUniqueness]
as_types

  [Param (TypeBase Shape NoUniqueness)]
par_a1 <- (VName
 -> TypeBase Shape NoUniqueness
 -> ADM (Param (TypeBase Shape NoUniqueness)))
-> [VName]
-> [TypeBase Shape NoUniqueness]
-> ADM [Param (TypeBase Shape NoUniqueness)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (\VName
x -> [Char]
-> TypeBase Shape NoUniqueness
-> ADM (Param (TypeBase Shape NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam (VName -> [Char]
baseString VName
x [Char] -> [Char] -> [Char]
forall a. Semigroup a => a -> a -> a
<> [Char]
"_par_a1")) [VName]
as [TypeBase Shape NoUniqueness]
arg_type_row
  [Param (TypeBase Shape NoUniqueness)]
par_a2 <- (VName
 -> TypeBase Shape NoUniqueness
 -> ADM (Param (TypeBase Shape NoUniqueness)))
-> [VName]
-> [TypeBase Shape NoUniqueness]
-> ADM [Param (TypeBase Shape NoUniqueness)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (\VName
x -> [Char]
-> TypeBase Shape NoUniqueness
-> ADM (Param (TypeBase Shape NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam (VName -> [Char]
baseString VName
x [Char] -> [Char] -> [Char]
forall a. Semigroup a => a -> a -> a
<> [Char]
"_par_a2")) [VName]
as [TypeBase Shape NoUniqueness]
arg_type_row
  -- Just the original operator but with par_a1 and par_a2 swapped.
  Lambda SOACS
rev_op <- [LParam (Rep ADM)] -> ADM Result -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda ([Param (TypeBase Shape NoUniqueness)]
par_a1 [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
forall a. Semigroup a => a -> a -> a
<> [Param (TypeBase Shape NoUniqueness)]
par_a2) (ADM Result -> ADM (Lambda (Rep ADM)))
-> ADM Result -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
    Lambda SOACS
op <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda (Lambda SOACS -> ADM (Lambda SOACS))
-> Lambda SOACS -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ Scan SOACS -> Lambda SOACS
forall rep. Scan rep -> Lambda rep
scanLambda Scan SOACS
scan
    Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM Result
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m Result
eLambda Lambda (Rep ADM)
Lambda SOACS
op ((Param (TypeBase Shape NoUniqueness) -> ADM (Exp SOACS))
-> [Param (TypeBase Shape NoUniqueness)] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> ADM (Exp (Rep ADM))
VName -> ADM (Exp SOACS)
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *). MonadBuilder m => VName -> m (Exp (Rep m))
toExp (VName -> ADM (Exp SOACS))
-> (Param (TypeBase Shape NoUniqueness) -> VName)
-> Param (TypeBase Shape NoUniqueness)
-> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param (TypeBase Shape NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName) ([Param (TypeBase Shape NoUniqueness)]
par_a2 [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
forall a. Semigroup a => a -> a -> a
<> [Param (TypeBase Shape NoUniqueness)]
par_a1))
  -- same neutral element
  let e :: [SubExp]
e = Scan SOACS -> [SubExp]
forall rep. Scan rep -> [SubExp]
scanNeutral Scan SOACS
scan
  let rev_scan :: Scan SOACS
rev_scan = Lambda SOACS -> [SubExp] -> Scan SOACS
forall rep. Lambda rep -> [SubExp] -> Scan rep
Scan Lambda SOACS
rev_op [SubExp]
e

  VName
iota <-
    [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"iota" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
  -- flip the input array (this code is inspired from the code in
  -- diffScanAdd, but made to work with [VName] instead VName)
  Lambda SOACS
map_scan <- [VName] -> ADM (Lambda SOACS)
revArrLam [VName]
as
  -- perform the scan
  [VName]
scan_res <-
    [Char] -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [VName]
letTupExp [Char]
"adj_ctrb_scan" (Exp SOACS -> ADM [VName])
-> (ScremaForm SOACS -> Exp SOACS)
-> ScremaForm SOACS
-> ADM [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> Exp SOACS)
-> (ScremaForm SOACS -> SOAC SOACS)
-> ScremaForm SOACS
-> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName
iota] (ScremaForm SOACS -> ADM [VName])
-> ScremaForm SOACS -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
      [Scan SOACS] -> Lambda SOACS -> ScremaForm SOACS
forall rep. [Scan rep] -> Lambda rep -> ScremaForm rep
scanomapSOAC [Scan SOACS
rev_scan] Lambda SOACS
map_scan
  -- flip the output array again
  Lambda SOACS
rev_lam <- [VName] -> ADM (Lambda SOACS)
revArrLam [VName]
scan_res
  [Char] -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [VName]
letTupExp [Char]
"reverse_scan_result" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ OpC (Rep ADM) (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (OpC (Rep ADM) (Rep ADM) -> Exp (Rep ADM))
-> OpC (Rep ADM) (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName
iota] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
rev_lam
  where
    revArrLam :: [VName] -> ADM (Lambda SOACS)
    revArrLam :: [VName] -> ADM (Lambda SOACS)
revArrLam [VName]
arrs = do
      Param (TypeBase Shape NoUniqueness)
par_i <- [Char]
-> TypeBase Shape NoUniqueness
-> ADM (Param (TypeBase Shape NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"i" (TypeBase Shape NoUniqueness
 -> ADM (Param (TypeBase Shape NoUniqueness)))
-> TypeBase Shape NoUniqueness
-> ADM (Param (TypeBase Shape NoUniqueness))
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase Shape NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
      [LParam (Rep ADM)] -> ADM Result -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda [Param (TypeBase Shape NoUniqueness)
LParam (Rep ADM)
par_i] (ADM Result -> ADM (Lambda SOACS))
-> ((VName -> ADM SubExpRes) -> ADM Result)
-> (VName -> ADM SubExpRes)
-> ADM (Lambda SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [VName] -> (VName -> ADM SubExpRes) -> ADM Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
arrs ((VName -> ADM SubExpRes) -> ADM (Lambda SOACS))
-> (VName -> ADM SubExpRes) -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ \VName
arr ->
        (VName -> SubExpRes) -> ADM VName -> ADM SubExpRes
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> SubExpRes
varRes (ADM VName -> ADM SubExpRes)
-> (Exp SOACS -> ADM VName) -> Exp SOACS -> ADM SubExpRes
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"ys_bar_rev"
          (Exp SOACS -> ADM SubExpRes) -> ADM (Exp SOACS) -> ADM SubExpRes
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> [ADM (Exp (Rep ADM))] -> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
VName -> [m (Exp (Rep m))] -> m (Exp (Rep m))
eIndex VName
arr [TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
w TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 (Param (TypeBase Shape NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape NoUniqueness)
par_i) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)]

-- | Build the lifted operator used by the PPAD scan.
--
-- The returned lambda takes two operands of the form @(x, a, y_h)@. Each
-- component holds one parameter per array in @as@, typed by that array's
-- row type. The full parameter list is therefore
-- @x1 ++ a1 ++ y1_h ++ x2 ++ a2 ++ y2_h@.
--
-- The lambda returns, in this order:
--
--   * @x1@, passed through unchanged. The second operand's @x2@ is ignored.
--   * @op a1 a2@, where @op@ is a freshly renamed copy of the scan operator.
--   * @z@, the elementwise sum (via 'addLambda') of two terms:
--
--       * the lambda produced by 'mkScanAdjointLam' ('WrtFirst', given
--         @y2_h@), applied to @x1 ++ a1@;
--       * @y1_h@.
--
-- NOTE(review): @y2_h@ is presumably the adjoint seed consumed by
-- 'mkScanAdjointLam'; confirm against its definition.
mkPPADOpLifted :: VjpOps -> [VName] -> Scan SOACS -> ADM (Lambda SOACS)
mkPPADOpLifted ops as scan = do
  arg_type_row <- map rowType <$> mapM lookupType as
  -- One fresh parameter per array in 'as', named after the array plus the
  -- given suffix. The creation order below is kept stable on purpose.
  let mkParams suffix =
        zipWithM (\x -> newParam (baseString x ++ suffix)) as arg_type_row
  par_x1 <- mkParams "_par_x1"
  par_x2_unused <- mkParams "_par_x2_unused"
  par_a1 <- mkParams "_par_a1"
  par_a2 <- mkParams "_par_a2"
  par_y1_h <- mkParams "_par_y1_h"
  par_y2_h <- mkParams "_par_y2_h"

  add_lams <- mapM addLambda arg_type_row

  mkLambda
    (par_x1 ++ par_a1 ++ par_y1_h ++ par_x2_unused ++ par_a2 ++ par_y2_h)
    (op_lift par_x1 par_a1 par_y1_h par_a2 par_y2_h add_lams)
  where
    -- Refer to each parameter by name as a 'SubExp'.
    paramSubExps :: [Param dec] -> [SubExp]
    paramSubExps = map (Var . paramName)

    op_lift ::
      [Param Type] ->
      [Param Type] ->
      [Param Type] ->
      [Param Type] ->
      [Param Type] ->
      [Lambda SOACS] ->
      ADM Result
    op_lift px1 pa1 py1 pa2 py2 adds = do
      op_bar_1 <-
        mkScanAdjointLam ops (scanLambda scan) WrtFirst (paramSubExps py2)
      z_term <-
        map resSubExp
          <$> eLambda op_bar_1 (map toExp (paramSubExps (px1 ++ pa1)))
      -- The addition lambdas (from 'addLambda') are expected to produce a
      -- single result. Fail with a descriptive message instead of a bare
      -- 'head' if that ever stops holding.
      let addOne (z_t, y_1, add) = do
            res <- eLambda add [toExp z_t, toExp y_1]
            case res of
              [r] -> pure r
              _ ->
                error
                  "mkPPADOpLifted: addition lambda must return exactly one result"
          z = mapM addOne (zip3 z_term (paramSubExps py1) adds)
          x1 = subExpsRes <$> mapM (toSubExp "x1") (paramSubExps px1)
      op <- renameLambda $ scanLambda scan
      let a3 = eLambda op (map (toExp . paramName) (pa1 ++ pa2))
      -- Statements are emitted in this order: x1, then a3, then z.
      concat <$> sequence [x1, a3, z]

-- | Shift every array in @as@ one position towards the front, padding the
-- last position with the corresponding entry of @e@.
--
-- A map over @iota w@ produces arrays of length @w@. Element @i@ is
-- @arr[i+1]@ when @i < w - 1@, and the matching entry of @e@ otherwise.
-- Arrays and entries of @e@ are paired positionally, as with 'zip'.
--
-- NOTE(review): assumes each array in @as@ has outer size @w@; confirm at
-- call sites.
asLiftPPAD :: [VName] -> SubExp -> [SubExp] -> ADM [VName]
asLiftPPAD as w e = do
  i_param <- newParam "i" $ Prim int64
  let i = paramName i_param
      -- Select the shifted element for one (array, padding) pair.
      liftOne (arr, arr_e) =
        fmap varRes . letExp "a_lift"
          =<< eIf
            (isBeforeLast i)
            ( buildBody_ $
                singleRes
                  <$> (letSubExp "val" =<< eIndex arr [toExp $ le64 i + 1])
            )
            (buildBody_ $ pure [subExpRes arr_e])
  lmb <- mkLambda [i_param] $ mapM liftOne $ zip as e
  iota <-
    letExp "iota" . BasicOp $
      Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letTupExp "as_lift" . Op . Screma w [iota] $ mapSOAC lmb
  where
    singleRes se = [subExpRes se]
    -- Builds the condition @i < w - 1@, binding @w - 1@ as "n_minus_one".
    isBeforeLast i = do
      nm1 <- toSubExp "n_minus_one" $ pe64 w - 1
      pure $ BasicOp $ CmpOp (CmpSlt Int64) (Var i) nm1

-- | Shift every array in @ys@ one position towards the back, filling the
-- first position with the corresponding entry of @e@.
--
-- A map over @iota w@ produces arrays of length @w@. Element @i@ is the
-- matching entry of @e@ when @i == 0@, and @arr[i-1]@ otherwise. Arrays and
-- entries of @e@ are paired positionally, as with 'zip'.
--
-- NOTE(review): assumes each array in @ys@ has outer size @w@; confirm at
-- call sites.
ysRightPPAD :: [VName] -> SubExp -> [SubExp] -> ADM [VName]
ysRightPPAD ys w e = do
  i_param <- newParam "i" $ Prim int64
  let i = paramName i_param
      -- True exactly for the first index.
      isFirst = BasicOp $ CmpOp (CmpEq int64) (Var i) (constant (0 :: Int64))
      -- Read the element preceding index i.
      prevElem arr = letSubExp "val" =<< eIndex arr [toExp $ le64 i - 1]
      shiftOne (arr, arr_e) = do
        shifted <-
          letExp "y_right"
            =<< eIf
              (pure isFirst)
              (buildBody_ $ pure [subExpRes arr_e])
              (buildBody_ $ (\se -> [subExpRes se]) <$> prevElem arr)
        pure $ varRes shifted
  lmb <- mkLambda [i_param] $ traverse shiftOne $ zip ys e
  iota <-
    letExp "iota" . BasicOp $
      Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letTupExp "ys_right" . Op . Screma w [iota] $ mapSOAC lmb

-- | Build the lambda used by the final map of the PPAD strategy.
--
-- For every array in @as@ the lambda takes three parameters of that array's
-- row type, grouped as: all @*_par_y_right@ parameters, then all @*_par_a@
-- parameters, then all @*_par_r_adj@ parameters.  Its body applies the
-- lambda produced by 'mkScanAdjointLam' (for the scan operator, 'WrtSecond',
-- with the @r_adj@ parameters as its extra arguments) to the @y_right@ and
-- @a@ parameters.
finalMapPPAD :: VjpOps -> [VName] -> Scan SOACS -> ADM (Lambda SOACS)
finalMapPPAD ops as scan = do
  row_ts <- map rowType <$> mapM lookupType as
  -- One fresh parameter per input array, named after that array.
  let paramsWithSuffix suffix =
        zipWithM (\v t -> newParam (baseString v ++ suffix) t) as row_ts
  y_right_ps <- paramsWithSuffix "_par_y_right"
  a_ps <- paramsWithSuffix "_par_a"
  r_adj_ps <- paramsWithSuffix "_par_r_adj"
  let argsOf = map (Var . paramName)
  mkLambda (y_right_ps ++ a_ps ++ r_adj_ps) $ do
    adj_op <- mkScanAdjointLam ops (scanLambda scan) WrtSecond (argsOf r_adj_ps)
    eLambda adj_op . map toExp $ argsOf (y_right_ps ++ a_ps)

-- | Reverse-mode rule for a general @scan@.
--
-- @ys@ are the arrays produced by the scan, @w@ is the length of the scanned
-- arrays, @as@ are the (unzipped) input arrays and @scan@ carries the
-- operator together with its neutral element.  'identifyCase' selects the
-- strategy used to compute the adjoint contributions for @as@
-- ('GenericPPAD' or 'GenericIFL23').  The contributions are then added to
-- the adjoints of @as@ with 'updateAdj'.
diffScan :: VjpOps -> [VName] -> SubExp -> [VName] -> Scan SOACS -> ADM ()
diffScan ops ys w as scan = do
  algo <- identifyCase ops scan_op
  let d = length as
  ys_bar <- mapM lookupAdjVal ys
  as_types <- mapM lookupType as

  contribs <- case algo of
    GenericPPAD -> do
      -- Run the lifted operator as a right-to-left scan over
      -- ys ++ lifted inputs ++ ys_bar.  Keep its third group of d results
      -- and feed them, with the results of 'ysRightPPAD' and the inputs,
      -- to the final map.
      as_lifted <- asLiftPPAD as w scan_ne
      lifted_op <- mkPPADOpLifted ops as scan
      zeros <-
        forM as_types $ \t ->
          Var <$> letExp "rscan_zero" (zeroExp (rowType t))
      scanned <-
        scanRight (ys ++ as_lifted ++ ys_bar) w $
          Scan lifted_op (scan_ne ++ scan_ne ++ zeros)
      let r_bar = chunk d scanned !! 2
      ys_shifted <- ysRightPPAD ys w scan_ne
      final_lam <- finalMapPPAD ops as scan
      letTupExp "as_bar" . Op . Screma w (ys_shifted ++ as ++ r_bar) $
        mapSOAC final_lam
    GenericIFL23 sc -> do
      -- IFL23: a scanomap over iota w, followed by 'mkScanFinalMap'.
      fused_map <- mkScanFusedMapLam ops w scan_op as ys ys_bar sc d
      lin_scan <- mkScanLinFunO (head as_types) sc
      lin_scans <- mkScans (specialScans sc) lin_scan
      iota <-
        letExp "iota" . BasicOp $
          Iota w (intConst Int64 0) (intConst Int64 1) Int64
      r_scan <-
        letTupExp "adj_ctrb_scan" . Op . Screma w [iota] $
          scanomapSOAC lin_scans fused_map
      mkScanFinalMap ops w scan_op as ys (splitScanRes sc r_scan d)

  -- as_bar += new adjoint
  zipWithM_ updateAdj as contribs
  where
    scan_op = scanLambda scan
    scan_ne = scanNeutral scan

    -- @mkScans k s@ yields @k@ copies of @s@.  Each copy's operator is passed
    -- through 'renameLambda' so the copies do not share binding names.
    mkScans :: Int -> Scan SOACS -> ADM [Scan SOACS]
    mkScans k s =
      replicateM k $ do
        lam' <- renameLambda (scanLambda s)
        pure (Scan lam' (scanNeutral s))

    -- Groups the scan results with 'orderArgs' and keeps the first
    -- @d `div` specialScans sc@ elements of each group.
    splitScanRes sc res d =
      concatMap (take (d `div` specialScans sc)) (orderArgs sc res)

-- | Reverse-mode rule for a scan whose elements are themselves arrays.
--
-- Each input in @as@ is transposed so that its two outermost dimensions are
-- swapped.  The original scan (of length @w@, operator @lam@) is then
-- re-expressed as a @map@ over the new outer dimension, whose body is a
-- scan with neutral elements taken row-wise from @ne@.  The @ys@ are bound
-- to the map results transposed back (under the statement auxiliary data
-- @aux@).  The statements built this way are collected and differentiated
-- with 'vjpStm', using @m@ as the innermost continuation.
diffScanVec ::
  VjpOps ->
  [VName] ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  [SubExp] ->
  [VName] ->
  ADM () ->
  ADM ()
diffScanVec ops ys aux w lam ne as m = do
  stmts <- collectStms_ $ do
    -- NOTE(review): assumes 'as' is non-empty ('head').
    rank <- arrayRank <$> lookupType (head as)
    -- Permutation swapping dimensions 0 and 1, leaving inner ones in place.
    let rear = [1, 0] ++ drop 2 [0 .. rank - 1]

    transp_as <-
      forM as $ \a ->
        letExp (baseString a ++ "_transp") $ BasicOp $ Rearrange a rear

    ts <- traverse lookupType transp_as
    let n = arraysSize 0 ts

    as_par <- traverse (newParam "as_par" . rowType) ts
    ne_par <- traverse (newParam "ne_par") $ lambdaReturnType lam

    scan_form <- scanSOAC [Scan lam (map (Var . paramName) ne_par)]

    map_lam <-
      mkLambda (as_par ++ ne_par) . fmap varsRes . letTupExp "map_res" . Op $
        Screma w (map paramName as_par) scan_form

    -- NOTE(review): 'subExpVars' keeps only the variables in 'ne'.  The
    -- elements of 'ne' are mapped over 'n' here, so they are presumably all
    -- array variables.  A constant would be dropped silently and misalign
    -- the inputs with 'ne_par'.  Confirm against callers.
    transp_ys <-
      letTupExp "trans_ys" . Op $
        Screma n (transp_as ++ subExpVars ne) (mapSOAC map_lam)

    -- Bind each original result name to its transposed-back array.  The
    -- per-element results are not needed, so use forM_.
    forM_ (zip ys transp_ys) $ \(y, x) ->
      auxing aux $ letBindNames [y] $ BasicOp $ Rearrange x rear

  foldr (vjpStm ops) m stmts

-- | Reverse-mode rule for a single-array scan, intended for an additive
-- operator.
--
-- The contribution added to the adjoint of @as@ is
-- @reverse (scan lam ne (reverse ys_bar))@.  Here @ys_bar@ is the adjoint of
-- @ys@, @n@ is the array length and the scan uses a renamed copy of @lam'@.
-- If @lam'@ is addition (as the name suggests; not checked here), this is
-- the suffix sum of @ys_bar@.  The 'VjpOps' argument is unused.
diffScanAdd :: VjpOps -> VName -> SubExp -> Lambda SOACS -> SubExp -> VName -> ADM ()
diffScanAdd _ops ys n lam' ne as = do
  lam <- renameLambda lam'
  ys_bar <- lookupAdjVal ys
  read_ys_bar_rev <- reversedReader ys_bar
  iota <-
    letExp "iota" . BasicOp $
      Iota n (intConst Int64 0) (intConst Int64 1) Int64
  scan_res <-
    letExp "res_rev" . Op . Screma n [iota] $
      scanomapSOAC [Scan lam [ne]] read_ys_bar_rev
  read_scan_rev <- reversedReader scan_res
  contrb <-
    letExp "contrb" . Op . Screma n [iota] $
      mapSOAC read_scan_rev
  updateAdj as contrb
  where
    -- A lambda mapping an index i to arr[n - i - 1], i.e. a reader that
    -- traverses 'arr' back to front when mapped over iota n.
    reversedReader :: VName -> ADM (Lambda SOACS)
    reversedReader arr = do
      i <- newParam "i" $ Prim int64
      mkLambda [i] $ do
        let idx = pe64 n - le64 (paramName i) - 1
        elem_v <- letExp "ys_bar_rev" =<< eIndex arr [toExp idx]
        pure [varRes elem_v]