{-# LANGUAGE TypeFamilies #-}
module Futhark.AD.Rev.Scatter (vjpScatter) where
import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Util (chunk)
-- | Build the conjunction of @0 <= i < q@ over every @(q, i)@ pair, where
-- @q@ is the (exclusive) upper bound and @i@ the index variable.  The lower
-- bound is expressed as @-1 < i@.  The empty list yields @true@; a
-- non-empty list is conjoined right-nested.
withinBounds :: [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds [] = TPrimExp $ ValueExp (BoolValue True)
withinBounds [(q, i)] =
  (le64 i .<. pe64 q) .&&. (pe64 (intConst Int64 (-1)) .<. le64 i)
withinBounds (qi : qis) = withinBounds [qi] .&&. withinBounds qis
-- | Generate the body of a (possibly nested) map lambda that reads
-- @arr[i_1, ..., i_k, j_1, ..., j_m]@.  The @i@s are the given parameters,
-- each paired with its bound @w@.  The @j@s are fresh indices that range
-- over the dimensions of the requested element type.
--
-- For an element type @[n_1]...[n_m]t@, this emits a tower of @m@ maps,
-- each over @iota n_l@.  Only at the innermost (scalar) level is the
-- bounds check on the outer @i@s performed.  When any of them is out of
-- bounds, the blank value of @t@ is returned instead of indexing.
-- Keeping the @if@ scalar avoids putting array-typed code under a branch,
-- which would not flatten well.
genIdxLamBody :: VName -> [(SubExp, Param Type)] -> Type -> ADM (Body SOACS)
genIdxLamBody as wpis = genRecLamBody as wpis []
  where
    genRecLamBody ::
      VName -> [(SubExp, Param Type)] -> [Param Type] -> Type -> ADM (Body SOACS)
    -- A zero-rank array type is handled as its element type.
    genRecLamBody arr w_pis nest_pis (Array t (Shape []) _) =
      genRecLamBody arr w_pis nest_pis (Prim t)
    -- Peel one dimension: map a fresh index parameter over @iota s@ and
    -- recurse on the remaining dimensions inside that map's lambda.
    genRecLamBody arr w_pis nest_pis (Array t (Shape (s : ss)) _) = do
      new_ip <- newParam "i" (Prim int64)
      let t' = Prim t `arrayOfShape` Shape ss
      inner_lam <-
        mkLambda [new_ip] $
          bodyBind =<< genRecLamBody arr w_pis (nest_pis ++ [new_ip]) t'
      let (_, orig_pis) = unzip w_pis
      buildBody_ . localScope (scopeOfLParams (orig_pis ++ nest_pis)) $ do
        iota_v <-
          letExp "iota" $ BasicOp $ Iota s (intConst Int64 0) (intConst Int64 1) Int64
        r <-
          letSubExp (baseString arr ++ "_elem") $
            Op $
              Screma s [iota_v] (mapSOAC inner_lam)
        pure [subExpRes r]
    -- Scalar level: check only the outer indices against their bounds (the
    -- nested indices come from iota and are in range by construction), then
    -- index with the outer indices followed by the nested ones.
    genRecLamBody arr w_pis nest_pis (Prim ptp) = do
      let (ws, orig_pis) = unzip w_pis
          inds = map paramName (orig_pis ++ nest_pis)
      localScope (scopeOfLParams (orig_pis ++ nest_pis)) $
        eBody
          [ eIf
              (toExp $ withinBounds $ zip ws $ map paramName orig_pis)
              ( do
                  r <- letSubExp "r" $ BasicOp $ Index arr $ Slice $ map (DimFix . Var) inds
                  resultBodyM [r]
              )
              (resultBodyM [Constant $ blankPrimValue ptp])
          ]
    genRecLamBody _ _ _ _ =
      error "In Rev.hs, helper function genRecLamBody, unreachable case reached!"
-- | Reverse-mode rule for a scatter with an identity lambda that writes a
-- single destination array @xs@, binding the result to @pys@.
--
-- @ass@ holds @rank * num_vals@ index arrays followed by @num_vals@ value
-- arrays, where @rank@ is the rank of @shp@.  Before the primal scatter is
-- re-emitted, the elements of @xs@ it will overwrite are gathered into
-- @xs_saves@.  The return sweep then does three things:
--
-- * Each value's adjoint receives a gather of @ys_adj@ at the written
--   indices.
-- * @xs_adj@ is built by scattering zeros into @ys_adj@ at the written
--   indices.
-- * @xs@ is restored by scattering @xs_saves@ back into @ys@.
vjpScatter1 ::
  PatElem Type ->
  StmAux () ->
  (SubExp, [VName], (Shape, Int, VName)) ->
  ADM () ->
  ADM ()
vjpScatter1 pys aux (w, ass, (shp, num_vals, xs)) m = do
  let rank = length $ shapeDims shp
      (all_inds, val_as) = splitAt (rank * num_vals) ass
      inds_as = chunk rank all_inds
  xs_t <- lookupType xs
  let val_t = stripArray (shapeRank shp) xs_t
      -- Every identity lambda built here has the same input layout as
      -- 'ass': @rank * num_vals@ indices followed by @num_vals@ values.
      -- (Previously the primal lambda used @shapeRank shp@ for both counts,
      -- which is ill-typed unless @rank == 1@ and @num_vals == 1@.)
      f_tps = replicate (rank * num_vals) (Prim int64) ++ replicate num_vals val_t
  -- Save the elements of xs that the scatter is about to overwrite.
  xs_saves <- mkGather inds_as xs xs_t
  -- Re-emit the primal scatter.
  id_lam <- mkIdentityLambda f_tps
  addStm $ Let (Pat [pys]) aux $ Op $ Scatter w ass [(shp, num_vals, xs)] id_lam
  m
  let ys = patElemName pys
  -- The restoring scatter at the end of the return sweep writes into ys.
  -- Keep a copy of ys and substitute it for ys afterwards, since ys may
  -- itself still be needed, e.g. as a program result.
  ys_copy <-
    letExp (baseString ys <> "_copy") . BasicOp $
      Replicate mempty (Var ys)
  returnSweepCode $ do
    ys_adj <- lookupAdjVal ys
    -- Adjoints of the scattered values: read ys_adj at the written indices.
    vs_ctrbs <- mkGather inds_as ys_adj xs_t
    zipWithM_ updateAdj val_as vs_ctrbs
    -- xs_adj: ys_adj with every written position zeroed.  Each value input
    -- of the scatter has type [w]val_t; for rank > 1 this differs from
    -- xs_t with its outer size set to w.
    zeros <-
      replicateM (length val_as) . letExp "zeros" . zeroExp $
        val_t `arrayOfShape` Shape [w]
    f <- mkIdentityLambda f_tps
    xs_adj <-
      letExp (baseString xs ++ "_adj") . Op $
        Scatter w (all_inds ++ zeros) [(shp, num_vals, ys_adj)] f
    -- The scatter above writes into ys_adj, so xs_adj takes ys_adj's place.
    insAdj xs xs_adj
    -- Restore xs by writing the saved elements back into ys.
    f' <- mkIdentityLambda f_tps
    xs_rc <-
      auxing aux . letExp (baseString xs <> "_rc") . Op $
        Scatter w (all_inds ++ xs_saves) [(shp, num_vals, ys)] f'
    addSubstitution xs xs_rc
    addSubstitution ys ys_copy
  where
    -- Gather, for each of the w scatter iterations and each index chunk in
    -- @inds_as@, the element of @arr@ at those indices.  The result is one
    -- array per chunk.  The lookup is a map nest built by 'genIdxLamBody',
    -- which applies the bounds check only at the innermost scalar level so
    -- that everything can stay parallel.
    mkGather :: [[VName]] -> VName -> Type -> ADM [VName]
    mkGather inds_as arr arr_t = do
      ips <- forM inds_as $ \idxs ->
        mapM (\idx -> newParam (baseString idx ++ "_elem") (Prim int64)) idxs
      gather_lam <- mkLambda (concat ips) . fmap mconcat . forM ips $ \idxs -> do
        let q = length idxs
            ws = take q $ arrayDims arr_t
            eltp = stripArray q arr_t
        bodyBind =<< genIdxLamBody arr (zip ws idxs) eltp
      let soac = Screma w (concat inds_as) (mapSOAC gather_lam)
      letTupExp (baseString arr ++ "_gather") $ Op soac
-- | Reverse-mode rule for 'Scatter'.
--
-- Only identity lambdas are supported.  A scatter with a single destination
-- is handled directly by 'vjpScatter1'.  A scatter with several
-- destinations is first split into one single-destination scatter per
-- destination.  Each of those is differentiated with 'vjpStm', nested so
-- that the continuation @m@ runs innermost.  Any other lambda is rejected
-- with an error.
vjpScatter ::
  VjpOps ->
  Pat Type ->
  StmAux () ->
  (SubExp, [VName], Lambda SOACS, [(Shape, Int, VName)]) ->
  ADM () ->
  ADM ()
vjpScatter ops (Pat pes) aux (w, ass, lam, written_info) m
  | isIdentityLambda lam,
    [(shp, num_vals, xs)] <- written_info,
    [pys] <- pes =
      vjpScatter1 pys aux (w, ass, (shp, num_vals, xs)) m
  | isIdentityLambda lam = do
      let sind = splitInd written_info
          (inds, vals) = splitAt sind ass
      lst_stms <- chunkScatterInps (inds, vals) (zip pes written_info)
      diffScatters (stmsFromList lst_stms)
  | otherwise =
      error "vjpScatter: cannot handle"
  where
    -- Total number of index arrays: each destination contributes
    -- @num_vals * rank@ of them.
    splitInd :: [(Shape, Int, VName)] -> Int
    splitInd [] = 0
    splitInd ((shp, num_res, _) : rest) =
      num_res * length (shapeDims shp) + splitInd rest

    -- Split the index and value inputs into one single-destination scatter
    -- statement per result, consuming inputs in destination order.  Any
    -- inputs left over once all destinations are processed are an error.
    chunkScatterInps ::
      ([VName], [VName]) ->
      [(PatElem Type, (Shape, Int, VName))] ->
      ADM [Stm SOACS]
    chunkScatterInps (acc_inds, acc_vals) [] =
      case (acc_inds, acc_vals) of
        ([], []) -> pure []
        _ -> error "chunkScatterInps: cannot handle"
    chunkScatterInps
      (acc_inds, acc_vals)
      ((pe, info@(shp, num_vals, _)) : rest) = do
        let num_inds = num_vals * length (shapeDims shp)
            (curr_inds, other_inds) = splitAt num_inds acc_inds
            (curr_vals, other_vals) = splitAt num_vals acc_vals
        vtps <- mapM lookupType curr_vals
        f <- mkIdentityLambda (replicate num_inds (Prim int64) ++ vtps)
        let stm =
              Let (Pat [pe]) aux . Op $
                Scatter w (curr_inds ++ curr_vals) [info] f
        stms_rest <- chunkScatterInps (other_inds, other_vals) rest
        pure $ stm : stms_rest

    -- Differentiate the statements in order, each one wrapping the
    -- differentiation of the rest; the original continuation runs last.
    diffScatters :: Stms SOACS -> ADM ()
    diffScatters all_stms
      | Just (stm, stms) <- stmsHead all_stms =
          vjpStm ops stm $ diffScatters stms
      | otherwise = m