{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.Hist
  ( diffMinMaxHist,
    diffMulHist,
    diffAddHist,
    diffVecHist,
    diffHist,
  )
where

import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename

-- The addition operator for a primitive type: 'Add' with 'OverflowUndef'
-- for integer types and 'FAdd' for floating-point types.  Any other
-- primitive type is unsupported and raises an error.
getBinOpPlus :: PrimType -> BinOp
getBinOpPlus (IntType x) = Add x OverflowUndef
getBinOpPlus (FloatType f) = FAdd f
-- The message previously named getBinOpMul, which pointed at the wrong
-- function when diagnosing a failure.
getBinOpPlus _ = error "In getBinOpPlus, Hist.hs: input not supported"

-- The division operator for a primitive type: 'SDiv' marked 'Unsafe' for
-- integer types and 'FDiv' for floating-point types.  Any other primitive
-- type is unsupported and raises an error.
getBinOpDiv :: PrimType -> BinOp
getBinOpDiv (IntType t) = SDiv t Unsafe
getBinOpDiv (FloatType t) = FDiv t
getBinOpDiv _ = error "In getBinOpDiv, Hist.hs: input not supported"

-- A boolean expression that holds when every index lies inside its bound.
-- Each pair (q, i) contributes the test `i < q && -1 < i`, with both sides
-- read as 64-bit integers; the tests of all pairs are joined with `.&&.`.
-- The empty list yields the constant true.
withinBounds :: [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds [] = TPrimExp $ ValueExp (BoolValue True)
withinBounds [(q, i)] =
  (le64 i .<. pe64 q) .&&. (pe64 (intConst Int64 (-1)) .<. le64 i)
withinBounds (qi : qis) = withinBounds [qi] .&&. withinBounds qis

-- Build a chain of nested conditionals.  Each pair (c1, c2) is compared
-- with 'CmpEq t'; when the comparison holds the body at the same position
-- is taken, otherwise evaluation falls through to the next pair.  The last
-- body is the final "else" branch, so the list of bodies must be exactly
-- one longer than the list of pairs; any other shape of input (including
-- an empty list of pairs) ends in an error.
elseIf ::
  (MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
  PrimType ->
  [(m (Exp (Rep m)), m (Exp (Rep m)))] ->
  [m (Body (Rep m))] ->
  m (Exp (Rep m))
elseIf t [(c1, c2)] [bt, bf] =
  eIf (eCmpOp (CmpEq t) c1 c2) bt bf
elseIf t ((c1, c2) : cs) (bt : bs) =
  -- The false branch is a body whose single result is the rest of the chain.
  eIf (eCmpOp (CmpEq t) c1 c2) bt $
    eBody [elseIf t cs bs]
elseIf _ _ _ = error "In elseIf, Hist.hs: input not supported"

-- Bind every result to a fresh variable.  For each 'SubExpRes' a new name
-- based on the string `s` is created and bound to the result's
-- subexpression, with the result's certificates attached to that binding.
-- The fresh names are returned in the order of the input.
bindSubExpRes :: (MonadBuilder m) => String -> [SubExpRes] -> m [VName]
bindSubExpRes s = traverse bindOne
  where
    bindOne (SubExpRes cs se) = do
      bn <- newVName s
      certifying cs $ letBindNames [bn] $ BasicOp $ SubExp se
      pure bn

-- Wrap a lambda in one map ('Screma' with 'mapSOAC') per dimension in the
-- given list.  The outermost map has width equal to the first dimension
-- and its parameters are arrays whose shape is the whole dimension list;
-- the lambda it maps is built recursively from the remaining dimensions.
-- One parameter named "x" is created per entry of the primitive type list.
-- With no dimensions the lambda is returned unchanged.
nestedmap :: [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [] _ lam = pure lam
nestedmap s@(h : r) pt lam = do
  params <- traverse (\tp -> newParam "x" $ Array tp (Shape s) NoUniqueness) pt
  body <- nestedmap r pt lam
  mkLambda params $
    fmap varsRes . letTupExp "res" . Op $
      Screma h (map paramName params) (mapSOAC body)

-- \ds hs -> map2 lam ds hs
--
-- Build a lambda that maps a renamed copy of `lam` (width `n`) over two
-- groups of fresh parameters, "ds_param" and "hs_param", each group having
-- one parameter per type in `tps`.  Returns the names of the ds parameters,
-- the names of the hs parameters, and the mapping lambda.
mkF' :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], [VName], Lambda SOACS)
mkF' lam tps n = do
  lam' <- renameLambda lam

  ds_params <- traverse (newParam "ds_param") tps
  hs_params <- traverse (newParam "hs_param") tps
  let ds_pars = fmap paramName ds_params
      hs_pars = fmap paramName hs_params
  lam_map <-
    mkLambda (ds_params <> hs_params) $
      fmap varsRes . letTupExp "map_f'" . Op $
        Screma n (ds_pars <> hs_pars) (mapSOAC lam')

  pure (ds_pars, hs_pars, lam_map)

-- \ls as rs -> map3 (\li ai ri -> li `lam` ai `lam` ri) ls as rs
--
-- Build a lambda that maps (width `n`) a three-way application of `lam`
-- over three groups of fresh parameters, "ls_param", "as_param" and
-- "rs_param", each group having one parameter per type in `tps`.
-- Returns the names of the as parameters together with the mapping lambda.
mkF :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], Lambda SOACS)
mkF lam tps n = do
  -- Two independently renamed copies: one for the left application and
  -- one for the right application.
  lam_l <- renameLambda lam
  lam_r <- renameLambda lam
  let q = length $ lambdaReturnType lam
      (lps, aps) = splitAt q $ lambdaParams lam_l
      (ips, rps) = splitAt q $ lambdaParams lam_r
  lam' <- mkLambda (lps <> aps <> rps) $ do
    lam_l_res <- bodyBind $ lambdaBody lam_l
    -- Feed the results of the left application into the first q parameters
    -- of the right copy by binding those parameter names directly.
    forM_ (zip ips lam_l_res) $ \(ip, SubExpRes cs se) ->
      certifying cs $ letBindNames [paramName ip] $ BasicOp $ SubExp se
    bodyBind $ lambdaBody lam_r

  ls_params <- traverse (newParam "ls_param") tps
  as_params <- traverse (newParam "as_param") tps
  rs_params <- traverse (newParam "rs_param") tps
  let map_params = ls_params <> as_params <> rs_params
  lam_map <-
    mkLambda map_params $
      fmap varsRes . letTupExp "map_f" $
        Op $
          Screma n (map paramName map_params) $
            mapSOAC lam'

  pure (map paramName as_params, lam_map)

-- Map over the index array `is` (width `n`), keeping every index `i` that
-- satisfies `withinBounds [(w, i)]` (that is, -1 < i < w) and replacing
-- every other index by `w`.  Returns the name of the resulting array.
mapout :: VName -> SubExp -> SubExp -> ADM VName
mapout is n w = do
  par_is <- newParam "is" $ Prim int64
  is'_lam <-
    mkLambda [par_is] $
      fmap varsRes . letTupExp "is'"
        =<< eIf
          (toExp $ withinBounds [(w, paramName par_is)])
          (eBody [eParam par_is])
          (eBody [eSubExp w])

  letExp "is'" $ Op $ Screma n [is] $ mapSOAC is'_lam

-- Scatter several value arrays `vs` into the destinations `dst` using one
-- shared index array `is`; `n` is the width of the scatter.  Returns the
-- names bound to the results of the scatter.
multiScatter :: SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter n dst is vs = do
  tps <- traverse lookupType vs
  par_i <- newParam "i" $ Prim int64
  scatter_params <- traverse (newParam "scatter_param" . rowType) tps
  -- The lambda returns the index once per value array, followed by the
  -- values themselves.
  scatter_lam <-
    mkLambda (par_i : scatter_params) $
      fmap subExpsRes . mapM (letSubExp "scatter_map_res") =<< do
        p1 <- replicateM (length scatter_params) $ eParam par_i
        p2 <- traverse eParam scatter_params
        pure $ p1 <> p2

  -- One spec entry per destination: a one-dimensional shape taken from the
  -- outer size of the corresponding value array's type, and the count 1.
  let spec = zipWith (\t d -> (Shape [arraySize 0 t], 1, d)) tps dst
  letTupExp "scatter_res" . Op $ Scatter n (is : vs) spec scatter_lam

-- Index every array in `vs` with the same leading dimension indices `s`.
-- The indices are extended to a full slice for each array's type via
-- 'fullSlice', and each result is bound to a fresh variable named "sorted".
-- Returns the new names in the order of the input.
multiIndex :: [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex vs s = traverse indexOne vs
  where
    indexOne x = do
      t <- lookupType x
      letExp "sorted" $ BasicOp $ Index x (fullSlice t s)

--
-- special case of histogram with min/max as operator.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--     let x = reduce_by_index dst minmax ne is vs
-- Forward sweep:
--     need to copy dst: reverse sweep might use it
--       (see ex. in reducebyindexminmax6.fut where the first map requires the original dst to be differentiated).
--     let dst_cpy = copy dst
--     let (x, x_inds) = zip vs (iota n)
--                       |> reduce_by_index (dst_cpy,-1s) argminmax (ne,-1) is
--
-- Reverse sweep:
--     dst_bar += map2 (\i b -> if i == -1
--                              then b
--                              else 0
--                     ) x_inds x_bar

--     vs_ctrbs = map2 (\i b -> if i == -1
--                              then 0
--                              else vs_bar[i] + b
--                     ) x_inds x_bar
--     vs_bar <- scatter vs_bar x_inds vs_ctrbs
diffMinMaxHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffMinMaxHist :: VjpOps
-> VName
-> StmAux ()
-> SubExp
-> BinOp
-> SubExp
-> VName
-> VName
-> SubExp
-> SubExp
-> VName
-> ADM ()
-> ADM ()
diffMinMaxHist VjpOps
_ops VName
x StmAux ()
aux SubExp
n BinOp
minmax SubExp
ne VName
is VName
vs SubExp
w SubExp
rf VName
dst ADM ()
m = do
  let t :: PrimType
t = BinOp -> PrimType
binOpType BinOp
minmax
  Type
vs_type <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
vs
  let vs_elm_type :: PrimType
vs_elm_type = Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
vs_type
  let vs_dims :: [SubExp]
vs_dims = Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
vs_type
  let inner_dims :: [SubExp]
inner_dims = [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
tail [SubExp]
vs_dims
  let nr_dims :: Int
nr_dims = [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs_dims
  Type
dst_type <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
dst
  let dst_dims :: [SubExp]
dst_dims = Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
dst_type

  VName
dst_cpy <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
dst String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_copy") (Exp SOACS -> ADM VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM VName) -> BasicOp -> ADM VName
forall a b. (a -> b) -> a -> b
$
      Shape -> SubExp -> BasicOp
Replicate Shape
forall a. Monoid a => a
mempty (VName -> SubExp
Var VName
dst)

  Param Type
acc_v_p <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"acc_v" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Param Type
acc_i_p <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"acc_i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Param Type
v_p <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"v" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Param Type
i_p <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
hist_lam_inner <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
acc_v_p, Param Type
LParam (Rep ADM)
acc_i_p, Param Type
LParam (Rep ADM)
v_p, Param Type
LParam (Rep ADM)
i_p] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"idx_res"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p))
          ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody
              [ Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p,
                BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> BinOp
SMin IntType
Int64) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_i_p) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
i_p)
              ]
          )
          ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody
              [ ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                  ( CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp
                      (PrimType -> CmpOp
CmpEq PrimType
t)
                      (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p)
                      (BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
minmax (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p))
                  )
                  ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p, Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_i_p])
                  ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p, Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
i_p])
              ]
          )
  Lambda SOACS
hist_lam <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
vs_elm_type, PrimType
int64, PrimType
vs_elm_type, PrimType
int64] Lambda SOACS
hist_lam_inner

  VName
dst_minus_ones <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"minus_ones" (Exp SOACS -> ADM VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM VName) -> BasicOp -> ADM VName
forall a b. (a -> b) -> a -> b
$
      Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
dst_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1))
  SubExp
ne_minus_ones <-
    String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"minus_ones" (Exp SOACS -> ADM SubExp)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM SubExp) -> BasicOp -> ADM SubExp
forall a b. (a -> b) -> a -> b
$
      Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
inner_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1))
  VName
iota_n <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"red_iota" (Exp SOACS -> ADM VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM VName) -> BasicOp -> ADM VName
forall a b. (a -> b) -> a -> b
$
      SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
n (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64

  VName
inp_iota <- do
    if Int
nr_dims Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1
      then VName -> ADM VName
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
iota_n
      else do
        Param Type
i <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
        Lambda SOACS
lam <-
          [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
i] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
            ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res" (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
              Exp SOACS -> ADM (Exp SOACS)
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp SOACS -> ADM (Exp SOACS)) -> Exp SOACS -> ADM (Exp SOACS)
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
inner_dims) (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
i

        String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"res" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam

  let hist_op :: HistOp SOACS
hist_op = Shape
-> SubExp -> [VName] -> [SubExp] -> Lambda SOACS -> HistOp SOACS
forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
rf [VName
dst_cpy, VName
dst_minus_ones] [SubExp
ne, if Int
nr_dims Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1 then IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1) else SubExp
ne_minus_ones] Lambda SOACS
hist_lam
  Lambda SOACS
f' <- [Type] -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64, Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
vs_type, Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType (Type -> Type) -> Type -> Type
forall a b. (a -> b) -> a -> b
$ PrimType -> Shape -> NoUniqueness -> Type
forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array PrimType
int64 ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
vs_dims) NoUniqueness
NoUniqueness]
  VName
x_inds <- String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (VName -> String
baseString VName
x String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_inds")
  StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$
    [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
x, VName
x_inds] (Exp (Rep ADM) -> ADM ()) -> Exp (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$
      Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$
        SubExp -> [VName] -> [HistOp SOACS] -> Lambda SOACS -> SOAC SOACS
forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist SubExp
n [VName
is, VName
vs, VName
inp_iota] [HistOp SOACS
hist_op] Lambda SOACS
f'

  ADM ()
m

  VName
x_bar <- VName -> ADM VName
lookupAdjVal VName
x

  Param Type
x_ind_dst <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_ind_param") (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Param Type
x_bar_dst <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar_param") (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Lambda SOACS
dst_lam_inner <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
x_ind_dst, Param Type
LParam (Rep ADM)
x_bar_dst] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"dst_bar"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
x_ind_dst) TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. -TPrimExp Int64 VName
1)
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
x_bar_dst)
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
  Lambda SOACS
dst_lam <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
int64, PrimType
vs_elm_type] Lambda SOACS
dst_lam_inner

  VName
dst_bar <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
dst String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar") (Exp SOACS -> ADM VName)
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM VName) -> SOAC SOACS -> ADM VName
forall a b. (a -> b) -> a -> b
$
      SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName
x_inds, VName
x_bar] (Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
dst_lam)

  VName -> VName -> ADM ()
updateAdj VName
dst VName
dst_bar

  VName
vs_bar <- VName -> ADM VName
lookupAdjVal VName
vs

  [VName]
inds' <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"inds" (Exp SOACS -> ADM VName)
-> (VName -> Exp SOACS) -> VName -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) (SubExp -> BasicOp) -> (VName -> SubExp) -> VName -> BasicOp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> ADM [VName]) -> ADM [VName] -> ADM [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [SubExp] -> [SubExp] -> ADM [VName]
mk_indices [SubExp]
inner_dims []
  let inds :: [VName]
inds = VName
x_inds VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
inds'

  [Param Type]
par_x_ind_vs <- Int -> ADM (Param Type) -> ADM [Param Type]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
nr_dims (ADM (Param Type) -> ADM [Param Type])
-> ADM (Param Type) -> ADM [Param Type]
forall a b. (a -> b) -> a -> b
$ String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_ind_param") (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Param Type
par_x_bar_vs <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar_param") (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Lambda SOACS
vs_lam_inner <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda (Param Type
par_x_bar_vs Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
par_x_ind_vs) (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 (Param Type -> VName
forall dec. Param dec -> VName
paramName (Param Type -> VName) -> Param Type -> VName
forall a b. (a -> b) -> a -> b
$ [Param Type] -> Param Type
forall a. HasCallStack => [a] -> a
head [Param Type]
par_x_ind_vs) TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. -TPrimExp Int64 VName
1)
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
          ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$
              ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ do
                SubExp
vs_bar_i <-
                  String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp (VName -> String
baseString VName
vs_bar String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_el") (Exp SOACS -> ADM SubExp)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM SubExp) -> BasicOp -> ADM SubExp
forall a b. (a -> b) -> a -> b
$
                    VName -> Slice SubExp -> BasicOp
Index VName
vs_bar (Slice SubExp -> BasicOp)
-> ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp]
-> BasicOp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> BasicOp) -> [DimIndex SubExp] -> BasicOp
forall a b. (a -> b) -> a -> b
$
                      (Param Type -> DimIndex SubExp)
-> [Param Type] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> (Param Type -> SubExp) -> Param Type -> DimIndex SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var (VName -> SubExp) -> (Param Type -> VName) -> Param Type -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall dec. Param dec -> VName
paramName) [Param Type]
par_x_ind_vs
                BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (PrimType -> BinOp
getBinOpPlus PrimType
t) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
par_x_bar_vs) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
vs_bar_i)
          )
  Lambda SOACS
vs_lam <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims (PrimType
vs_elm_type PrimType -> [PrimType] -> [PrimType]
forall a. a -> [a] -> [a]
: Int -> PrimType -> [PrimType]
forall a. Int -> a -> [a]
replicate Int
nr_dims PrimType
int64) Lambda SOACS
vs_lam_inner

  VName
vs_bar_p <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
vs String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_partial") (Exp SOACS -> ADM VName)
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM VName) -> SOAC SOACS -> ADM VName
forall a b. (a -> b) -> a -> b
$
      SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w (VName
x_bar VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
inds) (Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
vs_lam)

  SubExp
q <-
    String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"q"
      (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp -> SubExp -> [SubExp] -> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> SubExp -> [SubExp] -> m (Exp (Rep m))
foldBinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) [SubExp]
dst_dims

  [VName]
scatter_inps <- do
    -- traverse (letExp "flat" . BasicOp . Reshape [DimNew q]) $ inds ++ [vs_bar_p]
    -- ToDo: Cosmin asks: is the below the correct translation of the line above?
    (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"flat" (Exp SOACS -> ADM VName)
-> (VName -> Exp SOACS) -> VName -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ReshapeKind -> Shape -> VName -> BasicOp
Reshape ReshapeKind
ReshapeArbitrary ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
q])) ([VName] -> ADM [VName]) -> [VName] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
      [VName]
inds [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName
vs_bar_p]

  Lambda SOACS
f'' <- [Type] -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda ([Type] -> ADM (Lambda SOACS)) -> [Type] -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ Int -> Type -> [Type]
forall a. Int -> a -> [a]
replicate Int
nr_dims (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t]
  VName
vs_bar' <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
vs String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar") (Exp SOACS -> ADM VName)
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM VName) -> SOAC SOACS -> ADM VName
forall a b. (a -> b) -> a -> b
$
      SubExp
-> [VName] -> [(Shape, Int, VName)] -> Lambda SOACS -> SOAC SOACS
forall rep.
SubExp
-> [VName] -> [(Shape, Int, VName)] -> Lambda rep -> SOAC rep
Scatter SubExp
q [VName]
scatter_inps [([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
vs_dims, Int
1, VName
vs_bar)] Lambda SOACS
f''
  VName -> VName -> ADM ()
insAdj VName
vs VName
vs_bar'
  where
    -- Build index arrays for the dimensions given in the first argument.
    --
    -- The second argument carries the index values already fixed by the
    -- enclosing dimensions (it grows by one element per recursion step).
    --
    --   * No dimensions: nothing is emitted and the result is empty.
    --   * A single dimension @d@: every enclosing index is replicated to
    --     shape @[d]@, and an iota over @d@ is appended as the last result.
    --   * Several dimensions @d : inner@: an iota over @d@ is mapped with a
    --     lambda whose parameter is appended to the enclosing indices before
    --     recursing on @inner@; the results of that map are returned.
    mk_indices :: [SubExp] -> [SubExp] -> ADM [VName]
    mk_indices remaining enclosing =
      case remaining of
        [] -> pure []
        [d] -> do
          -- Keep the original emission order: replicates first, iota last.
          reps <- forM enclosing $ \ix ->
            letExp "rep" (BasicOp (Replicate (Shape [d]) ix))
          iota_d <- iotaOver d
          pure (reps <> [iota_d])
        d : inner -> do
          iota_d <- iotaOver d
          i_param <- newParam "i" (Prim int64)
          let enclosing' = enclosing <> [Var (paramName i_param)]
          lam <- mkLambda [i_param] (varsRes <$> mk_indices inner enclosing')
          letTupExp "res" (Op (Screma d [iota_d] (mapSOAC lam)))
      where
        -- An Int64 iota of length d, starting at 0 with stride 1.
        iotaOver :: SubExp -> ADM VName
        iotaOver d =
          letExp "red_iota" . BasicOp $
            Iota d (intConst Int64 0) (intConst Int64 1) Int64

--
-- Special case of a histogram with multiplication as the operator.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`:
--     let x = reduce_by_index dst (*) ne is vs
-- Forward sweep:
--     dst does not need to be copied, since dst is not overwritten.
--     (Below, `nes` is an array filled with `ne` and `0s` an array of zeros.)
--     let (ps, zs) = unzip (map (\v -> if v == 0 then (1,1) else (v,0)) vs)
--     let non_zero_prod = reduce_by_index nes (*) ne is ps
--     let zero_count = reduce_by_index 0s (+) 0 is zs
--     let h_part = map2 (\p c -> if c == 0 then p else 0
--                       ) non_zero_prod zero_count
--     let x = map2 (*) dst h_part
--
-- Reverse sweep:
--     dst_bar += map2 (*) h_part x_bar
--
--     let part_bar = map2 (*) dst x_bar
--     vs_bar += map2 (\i v -> let zr_cts = zero_count[i]
--                             let pr_bar = part_bar[i]
--                             let nz_prd = non_zero_prod[i]
--                             in if zr_cts == 0
--                             then pr_bar * (nz_prd / v)
--                             else if zr_cts == 1 && v == 0
--                             then nz_prd * pr_bar
--                             else 0
--                    ) is vs
diffMulHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffMulHist :: VjpOps
-> VName
-> StmAux ()
-> SubExp
-> BinOp
-> SubExp
-> VName
-> VName
-> SubExp
-> SubExp
-> VName
-> ADM ()
-> ADM ()
diffMulHist VjpOps
_ops VName
x StmAux ()
aux SubExp
n BinOp
mul SubExp
ne VName
is VName
vs SubExp
w SubExp
rf VName
dst ADM ()
m = do
  let t :: PrimType
t = BinOp -> PrimType
binOpType BinOp
mul
  Type
vs_type <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
vs
  let vs_dims :: [SubExp]
vs_dims = Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
vs_type
  let vs_elm_type :: PrimType
vs_elm_type = Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
vs_type
  Type
dst_type <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
dst
  let dst_dims :: [SubExp]
dst_dims = Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
dst_type
  let inner_dims :: [SubExp]
inner_dims = [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
tail [SubExp]
vs_dims

  Param Type
v_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"v" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Lambda SOACS
lam_ps_zs_inner <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
v_param] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"map_res"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_param) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t))
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ (SubExp -> ADM (Exp SOACS)) -> [SubExp] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp [PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
onePrimValue PrimType
t, IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1])
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_param, SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0])
  Lambda SOACS
lam_ps_zs <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
vs_dims [PrimType
vs_elm_type] Lambda SOACS
lam_ps_zs_inner
  [SubExpRes]
ps_zs_res <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_ps_zs [SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
vs]
  [VName]
ps_zs <- String -> [SubExpRes] -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> [SubExpRes] -> m [VName]
bindSubExpRes String
"ps_zs" [SubExpRes]
ps_zs_res
  let [VName
ps, VName
zs] = [VName]
ps_zs

  Lambda SOACS
lam_mul_inner <- BinOp -> PrimType -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
BinOp -> PrimType -> m (Lambda (Rep m))
binOpLambda BinOp
mul PrimType
t
  Lambda SOACS
lam_mul <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
vs_elm_type, PrimType
vs_elm_type] Lambda SOACS
lam_mul_inner
  VName
nz_prods0 <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"nz_prd" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
ne
  let hist_nzp :: HistOp SOACS
hist_nzp = Shape
-> SubExp -> [VName] -> [SubExp] -> Lambda SOACS -> HistOp SOACS
forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
rf [VName
nz_prods0] [SubExp
ne] Lambda SOACS
lam_mul

  Lambda SOACS
lam_add_inner <- BinOp -> PrimType -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
BinOp -> PrimType -> m (Lambda (Rep m))
binOpLambda (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef) PrimType
int64
  Lambda SOACS
lam_add <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
int64, PrimType
int64] Lambda SOACS
lam_add_inner
  VName
zr_counts0 <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"zr_cts" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
dst_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0)
  SubExp
zrn_ne <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"zr_ne" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
inner_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0)
  let hist_zrn :: HistOp SOACS
hist_zrn = Shape
-> SubExp -> [VName] -> [SubExp] -> Lambda SOACS -> HistOp SOACS
forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
rf [VName
zr_counts0] [if [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs_dims Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1 then IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0 else SubExp
zrn_ne] Lambda SOACS
lam_add

  Lambda SOACS
f' <- [Type] -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64, PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64, Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
vs_type, Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType (Type -> Type) -> Type -> Type
forall a b. (a -> b) -> a -> b
$ PrimType -> Shape -> NoUniqueness -> Type
forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array PrimType
int64 ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
vs_dims) NoUniqueness
NoUniqueness]
  VName
nz_prods <- String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"non_zero_prod"
  VName
zr_counts <- String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"zero_count"
  StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$
    [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
nz_prods, VName
zr_counts] (Exp (Rep ADM) -> ADM ()) -> Exp (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$
      Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$
        SubExp -> [VName] -> [HistOp SOACS] -> Lambda SOACS -> SOAC SOACS
forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist SubExp
n [VName
is, VName
is, VName
ps, VName
zs] [HistOp SOACS
hist_nzp, HistOp SOACS
hist_zrn] Lambda SOACS
f'

  Param Type
p_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"prod" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Param Type
c_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"count" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
lam_h_part_inner <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
p_param, Param Type
LParam (Rep ADM)
c_param] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"h_part"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
0 TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
c_param))
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
p_param)
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
  Lambda SOACS
lam_h_part <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
dst_dims [PrimType
vs_elm_type, PrimType
int64] Lambda SOACS
lam_h_part_inner
  [SubExpRes]
h_part_res <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_h_part ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
nz_prods, VName
zr_counts]
  [VName]
h_part' <- String -> [SubExpRes] -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> [SubExpRes] -> m [VName]
bindSubExpRes String
"h_part" [SubExpRes]
h_part_res
  let [VName
h_part] = [VName]
h_part'

  Lambda SOACS
lam_mul_inner' <- BinOp -> PrimType -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
BinOp -> PrimType -> m (Lambda (Rep m))
binOpLambda BinOp
mul PrimType
t
  Lambda SOACS
lam_mul' <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
dst_dims [PrimType
vs_elm_type, PrimType
vs_elm_type] Lambda SOACS
lam_mul_inner'
  [SubExpRes]
x_res <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_mul' ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
dst, VName
h_part]
  [VName]
x' <- String -> [SubExpRes] -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> [SubExpRes] -> m [VName]
bindSubExpRes String
"x" [SubExpRes]
x_res
  StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
x] (Exp (Rep ADM) -> ADM ()) -> Exp (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ [VName] -> VName
forall a. HasCallStack => [a] -> a
head [VName]
x'

  ADM ()
m

  VName
x_bar <- VName -> ADM VName
lookupAdjVal VName
x

  Lambda SOACS
lam_mul'' <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam_mul'
  [SubExpRes]
dst_bar_res <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_mul'' ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
h_part, VName
x_bar]
  [VName]
dst_bar <- String -> [SubExpRes] -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> [SubExpRes] -> m [VName]
bindSubExpRes (VName -> String
baseString VName
dst String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar") [SubExpRes]
dst_bar_res
  VName -> VName -> ADM ()
updateAdj VName
dst (VName -> ADM ()) -> VName -> ADM ()
forall a b. (a -> b) -> a -> b
$ [VName] -> VName
forall a. HasCallStack => [a] -> a
head [VName]
dst_bar

  Lambda SOACS
lam_mul''' <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam_mul'
  [SubExpRes]
part_bar_res <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_mul''' ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
dst, VName
x_bar]
  [VName]
part_bar' <- String -> [SubExpRes] -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> [SubExpRes] -> m [VName]
bindSubExpRes String
"part_bar" [SubExpRes]
part_bar_res
  let [VName
part_bar] = [VName]
part_bar'

  [Param Type]
inner_params <- (String -> Type -> ADM (Param Type))
-> [String] -> [Type] -> ADM [Param Type]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam [String
"zr_cts", String
"pr_bar", String
"nz_prd", String
"a"] ([Type] -> ADM [Param Type]) -> [Type] -> ADM [Param Type]
forall a b. (a -> b) -> a -> b
$ (PrimType -> Type) -> [PrimType] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim [PrimType
int64, PrimType
t, PrimType
t, PrimType
t]
  let [Param Type
zr_cts, Param Type
pr_bar, Param Type
nz_prd, Param Type
a_param] = [Param Type]
inner_params
  Lambda SOACS
lam_vsbar_inner <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type]
[LParam (Rep ADM)]
inner_params (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"vs_bar" (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
        ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
int64) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
zr_cts))
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
mul (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
pr_bar) (ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM)))
-> ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (PrimType -> BinOp
getBinOpDiv PrimType
t) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
nz_prd) (ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM)))
-> ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
a_param)
          ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$
              ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$
                ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                  ( BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
                      BinOp
LogAnd
                      (CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
int64) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
zr_cts))
                      (CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t) (ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM)))
-> ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
a_param)
                  )
                  ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
mul (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
nz_prd) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
pr_bar))
                  ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
          )

  Lambda SOACS
lam_vsbar_middle <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
int64, PrimType
t, PrimType
t, PrimType
t] Lambda SOACS
lam_vsbar_inner

  Param Type
i_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Param Type
a_param' <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"a" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
vs_type
  Lambda SOACS
lam_vsbar <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
i_param, Param Type
LParam (Rep ADM)
a_param'] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"vs_bar"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds ([(SubExp, VName)] -> TPrimExp Bool VName)
-> [(SubExp, VName)] -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ (SubExp, VName) -> [(SubExp, VName)]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp
w, Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
i_param))
          ( ADM [SubExpRes] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
m [SubExpRes] -> m (Body (Rep m))
buildBody_ (ADM [SubExpRes] -> ADM (Body (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
              let i :: Slice SubExp
i = Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
vs_type [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
i_param]
              [VName]
names <- (String -> ADM VName) -> [String] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName [String
"zr_cts", String
"pr_bar", String
"nz_prd"]
              (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (\VName
name -> [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
name] (Exp SOACS -> ADM ()) -> (VName -> Exp SOACS) -> VName -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> Slice SubExp -> BasicOp)
-> Slice SubExp -> VName -> BasicOp
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> Slice SubExp -> BasicOp
Index Slice SubExp
i) [VName]
names [VName
zr_counts, VName
part_bar, VName
nz_prods]
              Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_vsbar_middle ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName]
names [ADM (Exp SOACS)] -> [ADM (Exp SOACS)] -> [ADM (Exp SOACS)]
forall a. Semigroup a => a -> a -> a
<> [Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
a_param']
          )
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ Exp (Rep ADM) -> ADM (Exp (Rep ADM))
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp (Rep ADM) -> ADM (Exp (Rep ADM)))
-> Exp (Rep ADM) -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ Type -> Exp (Rep ADM)
forall rep. Type -> Exp rep
zeroExp (Type -> Exp (Rep ADM)) -> Type -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
dst_type)

  VName
vs_bar <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
vs String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar") (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
is, VName
vs] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam_vsbar

  VName -> VName -> ADM ()
updateAdj VName
vs VName
vs_bar

--
-- special case of histogram with add as operator.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--     let x = reduce_by_index dst (+) ne is vs
-- Forward sweep:
--     need to copy dst: the reverse sweep might use it
--       (see ex. in reducebyindexminmax6.fut where the first map requires the original dst to be differentiated).
--     let dst_cpy = copy dst
--     let x = reduce_by_index dst_cpy (+) ne is vs
--
-- Reverse sweep:
--     dst_bar += x_bar
--
--     vs_bar += map (\i -> x_bar[i]) is
-- Arguments, in order:
--   _ops : VJP operations (not used by this special case)
--   x    : name that the histogram result is bound to
--   aux  : auxiliary statement information attached to the re-emitted Hist
--   n    : width of the Hist and of the map that computes the adjoint of vs
--   add  : the combining operator
--   ne   : the neutral element passed to the HistOp
--   is   : array of indices
--   vs   : array of values
--   w    : size of the histogram (used as its shape and as the index bound)
--   rf   : passed through unchanged as the second field of the HistOp
--   dst  : destination array
--   m    : action run after the forward statements have been emitted and
--          before the adjoints of dst and vs are updated
diffAddHist ::
  VjpOps ->
  VName ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  SubExp ->
  VName ->
  VName ->
  SubExp ->
  SubExp ->
  VName ->
  ADM () ->
  ADM ()
diffAddHist _ops x aux n add ne is vs w rf dst m = do
  -- Element type, taken from the first parameter of the operator.  Matching
  -- on the list (rather than calling 'head') gives a meaningful message if
  -- the operator is malformed.
  let t = case lambdaParams add of
        p : _ -> paramDec p
        [] -> error "diffAddHist: combining operator has no parameters"

  -- Forward sweep: the histogram is run on an empty-shape Replicate of dst
  -- (bound as <dst>_copy) instead of on dst itself.
  dst_cpy <-
    letExp (baseString dst <> "_copy") . BasicOp $
      Replicate mempty (Var dst)

  f <- mkIdentityLambda [Prim int64, t]
  auxing aux . letBindNames [x] . Op $
    Hist n [is, vs] [HistOp (Shape [w]) rf [dst_cpy] [ne] add] f

  m

  -- Reverse sweep.
  x_bar <- lookupAdjVal x

  -- dst_bar += x_bar
  updateAdj dst x_bar

  -- vs_bar += map (\i -> x_bar[i]) is
  -- Indices for which 'withinBounds' does not hold for (w, i) contribute
  -- ne instead of a read from x_bar.
  x_type <- lookupType x
  i_param <- newParam (baseString vs <> "_i") $ Prim int64
  let i = paramName i_param
  lam_vsbar <-
    mkLambda [i_param] $
      fmap varsRes . letTupExp "vs_bar"
        =<< eIf
          (toExp $ withinBounds $ pure (w, i))
          (eBody $ pure $ pure $ BasicOp $ Index x_bar $ fullSlice x_type [DimFix $ Var i])
          (eBody $ pure $ eSubExp ne)

  vs_bar <- letExp (baseString vs <> "_bar") $ Op $ Screma n [is] $ mapSOAC lam_vsbar
  updateAdj vs vs_bar

-- Special case for vectorised combining operator. Rewrite
--   reduce_by_index dst (map2 op) nes is vss
-- to
--   map3 (\dst_col vss_col ne ->
--           reduce_by_index dst_col op ne is vss_col
--        ) (transpose dst) (transpose vss) nes |> transpose
-- before differentiating.
-- Arguments, in order: the VJP operations used to differentiate the
-- rewritten statements, the name x that the result is bound to, the
-- auxiliary information for the statement binding x, the Hist width n, the
-- operator op used in the per-column Hist, the array nes of neutral
-- elements, the index array is, the value array vss, the histogram size w,
-- the value rf passed through to the HistOp, the destination dst, and the
-- action m that is run innermost, after all rewritten statements have been
-- handed to 'vjpStm'.
diffVecHist ::
  VjpOps ->
  VName ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  VName ->
  VName ->
  VName ->
  SubExp ->
  SubExp ->
  VName ->
  ADM () ->
  ADM ()
diffVecHist ops x aux n op nes is vss w rf dst m = do
  -- Build the rewritten program without emitting it; the collected
  -- statements are differentiated below.
  stms <- collectStms_ $ do
    rank <- arrayRank <$> lookupType vss
    -- Permutation that swaps the two outermost dimensions and leaves the
    -- remaining ones in place.
    let dims = [1, 0] ++ drop 2 [0 .. rank - 1]

    dstT <- letExp "dstT" $ BasicOp $ Rearrange dims dst
    vssT <- letExp "vssT" $ BasicOp $ Rearrange dims vss
    t_dstT <- lookupType dstT
    t_vssT <- lookupType vssT
    t_nes <- lookupType nes

    -- Parameters of the outer map: one row of each transposed array and the
    -- matching neutral element.
    dst_col <- newParam "dst_col" $ rowType t_dstT
    vss_col <- newParam "vss_col" $ rowType t_vssT
    ne <- newParam "ne" $ rowType t_nes

    f <- mkIdentityLambda (Prim int64 : lambdaReturnType op)
    map_lam <-
      mkLambda [dst_col, vss_col, ne] $ do
        -- TODO Have to copy dst_col, but isn't it already unique?
        dst_col_cpy <-
          letExp "dst_col_cpy" . BasicOp $
            Replicate mempty (Var $ paramName dst_col)
        fmap (varsRes . pure) . letExp "col_res" $
          Op $
            Hist
              n
              [is, paramName vss_col]
              [HistOp (Shape [w]) rf [dst_col_cpy] [Var $ paramName ne] op]
              f
    histT <-
      letExp "histT" $
        Op $
          Screma (arraySize 0 t_dstT) [dstT, vssT, nes] $
            mapSOAC map_lam
    -- Applying the same permutation again restores the original layout; this
    -- is the statement that binds x.
    auxing aux . letBindNames [x] . BasicOp $ Rearrange dims histT
  -- Differentiate the statements in order, with m as the innermost
  -- continuation.
  foldr (vjpStm ops) m stms

--
-- A step in the radix sort implementation.
-- It assumes that the key we are sorting by
-- has type [n]i64 and is the first VName.
--
-- local def radix_sort_step [n] 't (xs: [n]t) (get_bit: i32 -> t -> i32)
--                                  (digit_n: i32): [n]t =
--   let num x = get_bit (digit_n+1) x * 2 + get_bit digit_n x
--   let pairwise op (a1,b1,c1,d1) (a2,b2,c2,d2) =
--     (a1 `op` a2, b1 `op` b2, c1 `op` c2, d1 `op` d2)
--   let bins = xs |> map num
--   let flags = bins |> map (\x -> if x == 0 then (1,0,0,0)
--                                  else if x == 1 then (0,1,0,0)
--                                  else if x == 2 then (0,0,1,0)
--                                  else (0,0,0,1))
--   let offsets = scan (pairwise (+)) (0,0,0,0) flags
--   let (na,nb,nc,_nd) = last offsets
--   let f bin (a,b,c,d) = match bin
--                         case 0 -> a-1
--                         case 1 -> na+b-1
--                         case 2 -> na+nb+c-1
--                         case _ -> na+nb+nc+d-1
--   let is = map2 f bins offsets
--   in scatter scratch is xs
radixSortStep :: [VName] -> [Type] -> SubExp -> SubExp -> SubExp -> ADM [VName]
radixSortStep :: [VName] -> [Type] -> SubExp -> SubExp -> SubExp -> ADM [VName]
radixSortStep [VName]
xs [Type]
tps SubExp
bit SubExp
n SubExp
w = do
  -- let is = head xs
  VName
is <- VName -> SubExp -> SubExp -> ADM VName
mapout ([VName] -> VName
forall a. HasCallStack => [a] -> a
head [VName]
xs) SubExp
n SubExp
w

  Param Type
num_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"num" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
num_lam <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
num_param] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"num_res"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
          (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef)
          ( BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
              (IntType -> BinOp
And IntType
Int64)
              (BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> BinOp
AShr IntType
Int64) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
num_param) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
bit))
              (Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)
          )
          ( BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
              (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef)
              (Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
2)
              ( BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
                  (IntType -> BinOp
And IntType
Int64)
                  (BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> BinOp
AShr IntType
Int64) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
num_param) (BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
bit) (Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)))
                  (Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)
              )
          )

  VName
bins <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"bins" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
is] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
num_lam
  Param Type
flag_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"flag" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
flag_lam <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
flag_param] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"flag_res"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PrimType
-> [(ADM (Exp (Rep ADM)), ADM (Exp (Rep ADM)))]
-> [ADM (Body (Rep ADM))]
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
PrimType
-> [(m (Exp (Rep m)), m (Exp (Rep m)))]
-> [m (Body (Rep m))]
-> m (Exp (Rep m))
elseIf
          PrimType
int64
          ((Integer -> (ADM (Exp SOACS), ADM (Exp SOACS)))
-> [Integer] -> [(ADM (Exp SOACS), ADM (Exp SOACS))]
forall a b. (a -> b) -> [a] -> [b]
map ((,) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
flag_param) (ADM (Exp SOACS) -> (ADM (Exp SOACS), ADM (Exp SOACS)))
-> (Integer -> ADM (Exp SOACS))
-> Integer
-> (ADM (Exp SOACS), ADM (Exp SOACS))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Integer -> ADM (Exp (Rep ADM))
Integer -> ADM (Exp SOACS)
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst) [Integer
0 .. Integer
2])
          ((Integer -> ADM (Body SOACS)) -> [Integer] -> [ADM (Body SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
[ADM (Exp (Rep ADM))] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body SOACS))
-> (Integer -> [ADM (Exp (Rep ADM))])
-> Integer
-> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Integer -> ADM (Exp (Rep ADM)))
-> [Integer] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst ([Integer] -> [ADM (Exp (Rep ADM))])
-> (Integer -> [Integer]) -> Integer -> [ADM (Exp (Rep ADM))]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (\Integer
i -> (Integer -> Integer) -> [Integer] -> [Integer]
forall a b. (a -> b) -> [a] -> [b]
map (\Integer
j -> if Integer
i Integer -> Integer -> Bool
forall a. Eq a => a -> a -> Bool
== Integer
j then Integer
1 else Integer
0) [Integer
0 .. Integer
3])) ([Integer
0 .. Integer
3] :: [Integer]))

  [VName]
flags <- String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"flags" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
bins] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
flag_lam

  [Param Type]
scan_params <- (String -> ADM (Param Type)) -> [String] -> ADM [Param Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse ((String -> Type -> ADM (Param Type))
-> Type -> String -> ADM (Param Type)
forall a b c. (a -> b -> c) -> b -> a -> c
flip String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (Type -> String -> ADM (Param Type))
-> Type -> String -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) [String
"a1", String
"b1", String
"c1", String
"d1", String
"a2", String
"b2", String
"c2", String
"d2"]
  Lambda SOACS
scan_lam <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type]
[LParam (Rep ADM)]
scan_params (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([SubExp] -> [SubExpRes]) -> ADM [SubExp] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [SubExp] -> [SubExpRes]
subExpsRes (ADM [SubExp] -> ADM [SubExpRes])
-> ([Exp SOACS] -> ADM [SubExp]) -> [Exp SOACS] -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Exp SOACS -> ADM SubExp) -> [Exp SOACS] -> ADM [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"scan_res") ([Exp SOACS] -> ADM [SubExpRes])
-> ADM [Exp SOACS] -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
        ([ADM (Exp SOACS)] -> [ADM (Exp SOACS)] -> ADM [Exp SOACS])
-> ([ADM (Exp SOACS)], [ADM (Exp SOACS)]) -> ADM [Exp SOACS]
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry ((ADM (Exp SOACS) -> ADM (Exp SOACS) -> ADM (Exp SOACS))
-> [ADM (Exp SOACS)] -> [ADM (Exp SOACS)] -> ADM [Exp SOACS]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (BinOp
 -> ADM (Exp (Rep ADM))
 -> ADM (Exp (Rep ADM))
 -> ADM (Exp (Rep ADM)))
-> BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef)) (([ADM (Exp SOACS)], [ADM (Exp SOACS)]) -> ADM [Exp SOACS])
-> ([ADM (Exp SOACS)], [ADM (Exp SOACS)]) -> ADM [Exp SOACS]
forall a b. (a -> b) -> a -> b
$ Int -> [ADM (Exp SOACS)] -> ([ADM (Exp SOACS)], [ADM (Exp SOACS)])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
4 ([ADM (Exp SOACS)] -> ([ADM (Exp SOACS)], [ADM (Exp SOACS)]))
-> [ADM (Exp SOACS)] -> ([ADM (Exp SOACS)], [ADM (Exp SOACS)])
forall a b. (a -> b) -> a -> b
$ (Param Type -> ADM (Exp SOACS))
-> [Param Type] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> ADM (Exp (Rep ADM))
Param Type -> ADM (Exp SOACS)
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam [Param Type]
scan_params

  ScremaForm SOACS
scan <- [Scan SOACS] -> ADM (ScremaForm SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Scan rep] -> m (ScremaForm rep)
scanSOAC ([Scan SOACS] -> ADM (ScremaForm SOACS))
-> [Scan SOACS] -> ADM (ScremaForm SOACS)
forall a b. (a -> b) -> a -> b
$ Scan SOACS -> [Scan SOACS]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Scan SOACS -> [Scan SOACS]) -> Scan SOACS -> [Scan SOACS]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [SubExp] -> Scan SOACS
forall rep. Lambda rep -> [SubExp] -> Scan rep
Scan Lambda SOACS
scan_lam ([SubExp] -> Scan SOACS) -> [SubExp] -> Scan SOACS
forall a b. (a -> b) -> a -> b
$ (Integer -> SubExp) -> [Integer] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (IntType -> Integer -> SubExp
intConst IntType
Int64) [Integer
0, Integer
0, Integer
0, Integer
0]
  [VName]
offsets <- String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"offsets" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName]
flags ScremaForm SOACS
scan

  SubExp
ind <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"ind_last" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> Overflow -> BinOp
Sub IntType
Int64 Overflow
OverflowUndef) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
n) (Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)
  let i :: Slice SubExp
i = [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
ind]
  [VName]
nabcd <- (String -> ADM VName) -> [String] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName [String
"na", String
"nb", String
"nc", String
"nd"]
  (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (\VName
abcd -> [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
abcd] (Exp SOACS -> ADM ()) -> (VName -> Exp SOACS) -> VName -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> Slice SubExp -> BasicOp)
-> Slice SubExp -> VName -> BasicOp
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> Slice SubExp -> BasicOp
Index Slice SubExp
i) [VName]
nabcd [VName]
offsets

  let vars :: [SubExp]
vars = (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
nabcd
  [Param Type]
map_params <- (String -> ADM (Param Type)) -> [String] -> ADM [Param Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse ((String -> Type -> ADM (Param Type))
-> Type -> String -> ADM (Param Type)
forall a b c. (a -> b -> c) -> b -> a -> c
flip String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (Type -> String -> ADM (Param Type))
-> Type -> String -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) [String
"bin", String
"a", String
"b", String
"c", String
"d"]
  Lambda SOACS
map_lam <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type]
[LParam (Rep ADM)]
map_params (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"map_res"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PrimType
-> [(ADM (Exp (Rep ADM)), ADM (Exp (Rep ADM)))]
-> [ADM (Body (Rep ADM))]
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
PrimType
-> [(m (Exp (Rep m)), m (Exp (Rep m)))]
-> [m (Body (Rep m))]
-> m (Exp (Rep m))
elseIf
          PrimType
int64
          ((Integer -> (ADM (Exp SOACS), ADM (Exp SOACS)))
-> [Integer] -> [(ADM (Exp SOACS), ADM (Exp SOACS))]
forall a b. (a -> b) -> [a] -> [b]
map ((,) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam (Param Type -> ADM (Exp (Rep ADM)))
-> Param Type -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ [Param Type] -> Param Type
forall a. HasCallStack => [a] -> a
head [Param Type]
map_params) (ADM (Exp SOACS) -> (ADM (Exp SOACS), ADM (Exp SOACS)))
-> (Integer -> ADM (Exp SOACS))
-> Integer
-> (ADM (Exp SOACS), ADM (Exp SOACS))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Integer -> ADM (Exp (Rep ADM))
Integer -> ADM (Exp SOACS)
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst) [Integer
0 .. Integer
2])
          ( (Int -> Param Type -> ADM (Body SOACS))
-> [Int] -> [Param Type] -> [ADM (Body SOACS)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
              ( \Int
j Param Type
p ->
                  [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$
                    ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ do
                      SubExp
t <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"t" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> Overflow -> BinOp
Sub IntType
Int64 Overflow
OverflowUndef) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
p) (Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)
                      BinOp -> SubExp -> [SubExp] -> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> SubExp -> [SubExp] -> m (Exp (Rep m))
foldBinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (SubExp
t SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
take Int
j [SubExp]
vars)
              )
              [Int
0 .. Int
3]
              ([Param Type] -> [Param Type]
forall a. HasCallStack => [a] -> [a]
tail [Param Type]
map_params)
          )

  VName
nis <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"nis" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n (VName
bins VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
offsets) (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
map_lam

  [VName]
scatter_dst <- (Type -> ADM VName) -> [Type] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (\Type
t -> String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"scatter_dst" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ PrimType -> [SubExp] -> BasicOp
Scratch (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t)) [Type]
tps
  SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter SubExp
n [VName]
scatter_dst VName
nis [VName]
xs
  where
    iConst :: Integer -> m (Exp (Rep m))
iConst Integer
c = SubExp -> m (Exp (Rep m))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> m (Exp (Rep m))) -> SubExp -> m (Exp (Rep m))
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
c

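--
-- The radix sort implementation.  Conceptually:
--
-- def radix_sort [n] 't (xs: [n]i64) =
--   let iters = if n == 0 then 0 else 32
--   in loop xs for i < iters do radix_sort_step xs i64.get_bit (i*2)
--
-- Unlike the sketch above, the number of iterations is not fixed at 32:
-- it is computed from the bound w as (log2 (w + 1) + 1) / 2, where log2
-- is the local helper below.  Each iteration handles the two key bits
-- starting at position i*2.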
radixSort :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort [VName]
xs SubExp
n SubExp
w = do
  SubExp
logw <- SubExp -> ADM SubExp
log2 (SubExp -> ADM SubExp) -> ADM SubExp -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"w1" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
w TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1)
  -- ceil logw by (logw + 1) / 2
  SubExp
iters <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"iters" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PrimExp VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
PrimExp VName -> m (Exp (Rep m))
toExp (TPrimExp Int64 VName -> PrimExp VName
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (SubExp -> TPrimExp Int64 VName
pe64 SubExp
logw TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1) PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
~/~ TPrimExp Int64 VName -> PrimExp VName
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (SubExp -> TPrimExp Int64 VName
pe64 (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
2)))

  [Type]
types <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
xs
  [Param (TypeBase Shape Uniqueness)]
params <- (VName -> Type -> ADM (Param (TypeBase Shape Uniqueness)))
-> [VName] -> [Type] -> ADM [Param (TypeBase Shape Uniqueness)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (\VName
x -> String
-> TypeBase Shape Uniqueness
-> ADM (Param (TypeBase Shape Uniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x) (TypeBase Shape Uniqueness
 -> ADM (Param (TypeBase Shape Uniqueness)))
-> (Type -> TypeBase Shape Uniqueness)
-> Type
-> ADM (Param (TypeBase Shape Uniqueness))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> TypeBase Shape Uniqueness)
-> Uniqueness -> Type -> TypeBase Shape Uniqueness
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> TypeBase Shape Uniqueness
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Nonunique) [VName]
xs [Type]
types
  VName
i <- String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"i"
  Body SOACS
loopbody <- ADM [SubExpRes] -> ADM (Body (Rep ADM))
ADM [SubExpRes] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
m [SubExpRes] -> m (Body (Rep m))
buildBody_ (ADM [SubExpRes] -> ADM (Body SOACS))
-> (ADM [SubExpRes] -> ADM [SubExpRes])
-> ADM [SubExpRes]
-> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Scope SOACS -> ADM [SubExpRes] -> ADM [SubExpRes]
forall a. Scope SOACS -> ADM a -> ADM a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([Param (TypeBase Shape Uniqueness)] -> Scope SOACS
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param (TypeBase Shape Uniqueness)]
params) (ADM [SubExpRes] -> ADM (Body SOACS))
-> ADM [SubExpRes] -> ADM (Body SOACS)
forall a b. (a -> b) -> a -> b
$
    ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ do
      SubExp
bit <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"bit" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
2)
      [VName] -> [Type] -> SubExp -> SubExp -> SubExp -> ADM [VName]
radixSortStep ((Param (TypeBase Shape Uniqueness) -> VName)
-> [Param (TypeBase Shape Uniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName [Param (TypeBase Shape Uniqueness)]
params) [Type]
types SubExp
bit SubExp
n SubExp
w

  String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"sorted" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
    [(FParam SOACS, SubExp)] -> LoopForm -> Body SOACS -> Exp SOACS
forall rep.
[(FParam rep, SubExp)] -> LoopForm -> Body rep -> Exp rep
Loop
      ([FParam SOACS] -> [SubExp] -> [(FParam SOACS, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase Shape Uniqueness)]
[FParam SOACS]
params ([SubExp] -> [(FParam SOACS, SubExp)])
-> [SubExp] -> [(FParam SOACS, SubExp)]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
xs)
      (VName -> IntType -> SubExp -> LoopForm
ForLoop VName
i IntType
Int64 SubExp
iters)
      Body SOACS
loopbody
  where
    log2 :: SubExp -> ADM SubExp
    log2 :: SubExp -> ADM SubExp
log2 SubExp
m = do
      [Param (TypeBase Shape Uniqueness)]
params <- (String
 -> TypeBase Shape Uniqueness
 -> ADM (Param (TypeBase Shape Uniqueness)))
-> [String]
-> [TypeBase Shape Uniqueness]
-> ADM [Param (TypeBase Shape Uniqueness)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM String
-> TypeBase Shape Uniqueness
-> ADM (Param (TypeBase Shape Uniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam [String
"cond", String
"r", String
"i"] ([TypeBase Shape Uniqueness]
 -> ADM [Param (TypeBase Shape Uniqueness)])
-> [TypeBase Shape Uniqueness]
-> ADM [Param (TypeBase Shape Uniqueness)]
forall a b. (a -> b) -> a -> b
$ (PrimType -> TypeBase Shape Uniqueness)
-> [PrimType] -> [TypeBase Shape Uniqueness]
forall a b. (a -> b) -> [a] -> [b]
map PrimType -> TypeBase Shape Uniqueness
forall shape u. PrimType -> TypeBase shape u
Prim [PrimType
Bool, PrimType
int64, PrimType
int64]
      let [Param (TypeBase Shape Uniqueness)
cond, Param (TypeBase Shape Uniqueness)
r, Param (TypeBase Shape Uniqueness)
i] = [Param (TypeBase Shape Uniqueness)]
params

      Body SOACS
body <- ADM [SubExpRes] -> ADM (Body (Rep ADM))
ADM [SubExpRes] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
m [SubExpRes] -> m (Body (Rep m))
buildBody_ (ADM [SubExpRes] -> ADM (Body SOACS))
-> (ADM [SubExpRes] -> ADM [SubExpRes])
-> ADM [SubExpRes]
-> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Scope SOACS -> ADM [SubExpRes] -> ADM [SubExpRes]
forall a. Scope SOACS -> ADM a -> ADM a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([Param (TypeBase Shape Uniqueness)] -> Scope SOACS
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param (TypeBase Shape Uniqueness)]
params) (ADM [SubExpRes] -> ADM (Body SOACS))
-> ADM [SubExpRes] -> ADM (Body SOACS)
forall a b. (a -> b) -> a -> b
$ do
        SubExp
r' <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"r'" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 (Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape Uniqueness)
r) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.>>. TPrimExp Int64 VName
1)
        SubExp
cond' <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"cond'" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot (TPrimExp Bool VName -> TPrimExp Bool VName)
-> TPrimExp Bool VName -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 VName
pe64 SubExp
r' TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0)
        SubExp
i' <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"i'" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 (Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape Uniqueness)
i) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1)
        [SubExpRes] -> ADM [SubExpRes]
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([SubExpRes] -> ADM [SubExpRes]) -> [SubExpRes] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExpRes]
subExpsRes [SubExp
cond', SubExp
r', SubExp
i']

      SubExp
cond_init <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"test" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot (TPrimExp Bool VName -> TPrimExp Bool VName)
-> TPrimExp Bool VName -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 VName
pe64 SubExp
m TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0)

      [SubExp]
l <-
        String -> Exp (Rep ADM) -> ADM [SubExp]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [SubExp]
letTupExp' String
"log2res" (Exp (Rep ADM) -> ADM [SubExp]) -> Exp (Rep ADM) -> ADM [SubExp]
forall a b. (a -> b) -> a -> b
$
          [(FParam SOACS, SubExp)] -> LoopForm -> Body SOACS -> Exp SOACS
forall rep.
[(FParam rep, SubExp)] -> LoopForm -> Body rep -> Exp rep
Loop
            ([Param (TypeBase Shape Uniqueness)]
-> [SubExp] -> [(Param (TypeBase Shape Uniqueness), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase Shape Uniqueness)]
params [SubExp
cond_init, SubExp
m, PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
int64])
            (VName -> LoopForm
WhileLoop (VName -> LoopForm) -> VName -> LoopForm
forall a b. (a -> b) -> a -> b
$ Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape Uniqueness)
cond)
            Body SOACS
body

      let [SubExp
_, SubExp
_, SubExp
res] = [SubExp]
l
      SubExp -> ADM SubExp
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
res

-- | Sort the first array of @xs@ (the keys) with 'radixSort' and reorder the
-- remaining arrays of @xs@ (the payloads) accordingly.
--
-- An @iota n@ array is handed to 'radixSort' together with the keys; the
-- second array it returns is then used as the index array of a width-@n@
-- @map@ that reads every payload array at those indices.  The arguments @n@
-- and @w@ are forwarded to 'radixSort' unchanged.
--
-- The result is laid out as
-- @sortedIota : sortedKeys : gatheredPayloads@.
--
-- NOTE(review): this presumes 'radixSort' returns its inputs in the same
-- order it received them (keys first, iota second) -- confirm against its
-- definition.
--
-- Calls 'error' with a descriptive message if @xs@ is empty or if
-- 'radixSort' does not return exactly two arrays.
radixSort' :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort' xs n w = do
  -- Split the input into key array and payload arrays.  The binding is a
  -- lazy @let@, so a malformed input is only reported when the keys or
  -- payloads are actually demanded.
  let (keys, payloads) = case xs of
        k : ps -> (k, ps)
        [] -> error "radixSort': expected a key array, but got no arrays"

  iotaN <-
    letExp "red_iota" . BasicOp $
      Iota n (intConst Int64 0) (intConst Int64 1) Int64

  radixResult <- radixSort [keys, iotaN] n w
  let (sortedKeys, sortedIota) = case radixResult of
        [ks, is] -> (ks, is)
        _ ->
          error $
            "radixSort': expected radixSort to return exactly two arrays, "
              <> "but got "
              <> show (length radixResult)

  -- Lambda that, given an index i, reads element i of every payload array.
  idxParam <- newParam "i" $ Prim int64
  let idxSlice = [DimFix $ Var $ paramName idxParam]
  gatherLam <-
    mkLambda [idxParam] $
      varsRes <$> multiIndex payloads idxSlice

  -- Map the lambda over the second array returned by radixSort.
  gathered <-
    letTupExp "sorted" $
      Op $
        Screma n [sortedIota] $
          mapSOAC gatherLam

  pure $ sortedIota : sortedKeys : gathered

--
-- generic case of histogram.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--   let xs = reduce_by_index dst odot ne is as
-- Forward sweep:
-- let h_part = reduce_by_index (replicate w ne) odot ne is as
-- let xs = map2 odot dst h_part
-- Reverse sweep:
-- h_part_bar += f'' dst h_part
-- dst_bar += f' dst h_part

-- let flag = map (\i -> i == 0 || sis[i] != sis[i-1]) (iota n)
-- let flag_rev = map (\i -> i==0 || flag[n-i]) (iota n)
-- let ls = seg_scan_exc odot ne flag sas
-- let rs = reverse sas |>
--          seg_scan_exc odot ne flag_rev |> reverse
-- let f_bar = map (\i -> if i < w && -1 < i
--                        then h_part_bar[i]
--                        else 0s
--                 ) sis
-- let sas_bar = f f_dst ls sas rs
-- as_bar += scatter (Scratch alpha n) siota sas_bar
-- Where:
--  siota: 'iota n' sorted wrt 'is'
--  sis: 'is' sorted wrt 'is'
--  sas: 'as' sorted wrt 'is'
--  f'' = vjpLambda xs_bar h_part (map2 odot)
--  f' = vjpLambda xs_bar dst (map2 odot)
--  f  = vjpLambda f_bar sas (map4 (\di li ai ri -> di odot li odot ai odot ri))
--  0s is an alpha-dimensional array with 0 (possibly 0-dim)
diffHist :: VjpOps -> [VName] -> StmAux () -> SubExp -> Lambda SOACS -> [SubExp] -> [VName] -> [SubExp] -> SubExp -> [VName] -> ADM () -> ADM ()
diffHist :: VjpOps
-> [VName]
-> StmAux ()
-> SubExp
-> Lambda SOACS
-> [SubExp]
-> [VName]
-> [SubExp]
-> SubExp
-> [VName]
-> ADM ()
-> ADM ()
diffHist VjpOps
ops [VName]
xs StmAux ()
aux SubExp
n Lambda SOACS
lam0 [SubExp]
ne [VName]
as [SubExp]
w SubExp
rf [VName]
dst ADM ()
m = do
  [Type]
as_type <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType ([VName] -> ADM [Type]) -> [VName] -> ADM [Type]
forall a b. (a -> b) -> a -> b
$ [VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
tail [VName]
as
  [Type]
dst_type <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
dst

  [VName]
nes <- (SubExp -> ADM VName) -> [SubExp] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"new_dst" (Exp SOACS -> ADM VName)
-> (SubExp -> Exp SOACS) -> SubExp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS)
-> (SubExp -> BasicOp) -> SubExp -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> [SubExp] -> Shape
forall a b. (a -> b) -> a -> b
$ SubExp -> [SubExp]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp -> [SubExp]) -> SubExp -> [SubExp]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
head [SubExp]
w)) [SubExp]
ne

  Lambda SOACS
h_map <- [Type] -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda ([Type] -> ADM (Lambda SOACS)) -> [Type] -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64 Type -> [Type] -> [Type]
forall a. a -> [a] -> [a]
: (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType [Type]
as_type
  [VName]
h_part <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> ADM VName) -> (VName -> String) -> VName -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (String -> String -> String) -> String -> String -> String
forall a b c. (a -> b -> c) -> b -> a -> c
flip String -> String -> String
forall a. Semigroup a => a -> a -> a
(<>) String
"_h_part" (String -> String) -> (VName -> String) -> VName -> String
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> String
baseString) [VName]
xs
  StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ())
-> (SOAC SOACS -> ADM ()) -> SOAC SOACS -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName]
h_part (Exp SOACS -> ADM ())
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM ()) -> SOAC SOACS -> ADM ()
forall a b. (a -> b) -> a -> b
$
    SubExp -> [VName] -> [HistOp SOACS] -> Lambda SOACS -> SOAC SOACS
forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist SubExp
n [VName]
as [Shape
-> SubExp -> [VName] -> [SubExp] -> Lambda SOACS -> HistOp SOACS
forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
w) SubExp
rf [VName]
nes [SubExp]
ne Lambda SOACS
lam0] Lambda SOACS
h_map

  Lambda SOACS
lam0' <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
  StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ())
-> (SOAC SOACS -> ADM ()) -> SOAC SOACS -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName]
xs (Exp SOACS -> ADM ())
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM ()) -> SOAC SOACS -> ADM ()
forall a b. (a -> b) -> a -> b
$
    SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma ([SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
head [SubExp]
w) ([VName]
dst [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
h_part) (Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam0')

  ADM ()
m

  [VName]
xs_bar <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse VName -> ADM VName
lookupAdjVal [VName]
xs

  ([VName]
dst_params, [VName]
hp_params, Lambda SOACS
f') <- Lambda SOACS
-> [Type] -> SubExp -> ADM ([VName], [VName], Lambda SOACS)
mkF' Lambda SOACS
lam0 [Type]
dst_type (SubExp -> ADM ([VName], [VName], Lambda SOACS))
-> SubExp -> ADM ([VName], [VName], Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ [SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
head [SubExp]
w
  Lambda SOACS
f'_adj_dst <- VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops ((VName -> Adj) -> [VName] -> [Adj]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Adj
adjFromVar [VName]
xs_bar) [VName]
dst_params Lambda SOACS
f'
  Lambda SOACS
f'_adj_hp <- VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops ((VName -> Adj) -> [VName] -> [Adj]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Adj
adjFromVar [VName]
xs_bar) [VName]
hp_params Lambda SOACS
f'

  [SubExpRes]
dst_bar' <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
f'_adj_dst ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp (Rep ADM))) -> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> [ADM (Exp (Rep ADM))])
-> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ [VName]
dst [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
h_part
  [VName]
dst_bar <- String -> [SubExpRes] -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> [SubExpRes] -> m [VName]
bindSubExpRes String
"dst_bar" [SubExpRes]
dst_bar'
  (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
updateAdj [VName]
dst [VName]
dst_bar

  [SubExpRes]
h_part_bar' <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
f'_adj_hp ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp (Rep ADM))) -> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> [ADM (Exp (Rep ADM))])
-> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ [VName]
dst [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
h_part
  [VName]
h_part_bar <- String -> [SubExpRes] -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> [SubExpRes] -> m [VName]
bindSubExpRes String
"h_part_bar" [SubExpRes]
h_part_bar'

  Lambda SOACS
lam <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
  Lambda SOACS
lam' <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0

  -- is' <- mapout (head as) n (head w)
  -- sorted <- radixSort' (is' : tail as) n $ head w
  [VName]
sorted <- [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort' [VName]
as SubExp
n (SubExp -> ADM [VName]) -> SubExp -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
head [SubExp]
w
  let siota :: VName
siota = [VName] -> VName
forall a. HasCallStack => [a] -> a
head [VName]
sorted
  let sis :: VName
sis = [VName] -> VName
forall a. HasCallStack => [a] -> a
head ([VName] -> VName) -> [VName] -> VName
forall a b. (a -> b) -> a -> b
$ [VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
tail [VName]
sorted
  let sas :: [VName]
sas = Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop Int
2 [VName]
sorted

  VName
iota_n <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"iota" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
n (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64

  Param Type
par_i <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
flag_lam <- LParam SOACS -> VName -> ADM (Lambda SOACS)
mkFlagLam Param Type
LParam SOACS
par_i VName
sis
  VName
flag <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"flag" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
flag_lam

  -- map (\i -> (if flag[i] then (true,ne) else (false,vs[i-1]), if i==0 || flag[n-i] then (true,ne) else (false,vs[n-i]))) (iota n)
  Param Type
par_i' <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  let i' :: VName
i' = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i'
  Lambda SOACS
g_lam <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
par_i'] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([SubExp] -> [SubExpRes]) -> ADM [SubExp] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [SubExp] -> [SubExpRes]
subExpsRes (ADM [SubExp] -> ADM [SubExpRes])
-> ([Exp SOACS] -> ADM [SubExp]) -> [Exp SOACS] -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Exp SOACS -> ADM SubExp) -> [Exp SOACS] -> ADM [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"scan_inps") ([Exp SOACS] -> ADM [SubExpRes])
-> ADM [Exp SOACS] -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
        SubExp
im1 <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"i_1" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
        SubExp
nmi <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"n_i" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
n TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i')
        let s1 :: [DimIndex SubExp]
s1 = [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
im1]
        let s2 :: [DimIndex SubExp]
s2 = [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
nmi]

        -- flag array for left scan
        SubExp
f1 <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"f1" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
flag (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i']

        -- array for left scan
        [SubExp]
r1 <-
          String -> Exp (Rep ADM) -> ADM [SubExp]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [SubExp]
letTupExp' String
"r1"
            (Exp SOACS -> ADM [SubExp]) -> ADM (Exp SOACS) -> ADM [SubExp]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
              (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
f1)
              ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ (SubExp -> ADM (Exp SOACS)) -> [SubExp] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp [SubExp]
ne)
              ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
[ADM (Exp SOACS)] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp SOACS)] -> ADM (Body SOACS))
-> ([VName] -> [ADM (Exp SOACS)]) -> [VName] -> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> ADM (Body SOACS)) -> ADM [VName] -> ADM (Body SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
sas [DimIndex SubExp]
s1)

        -- array for right scan inc flag
        [SubExp]
r2 <-
          String -> Exp (Rep ADM) -> ADM [SubExp]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [SubExp]
letTupExp' String
"r2"
            (Exp SOACS -> ADM [SubExp]) -> ADM (Exp SOACS) -> ADM [SubExp]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
              (TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i' TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0)
              ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ (SubExp -> ADM (Exp (Rep ADM)))
-> [SubExp] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp] -> [ADM (Exp (Rep ADM))])
-> [SubExp] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimType -> PrimValue
onePrimValue PrimType
Bool) SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
ne)
              ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$
                  ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ do
                    ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                      (Exp (Rep ADM) -> ADM (Exp (Rep ADM))
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp (Rep ADM) -> ADM (Exp (Rep ADM)))
-> Exp (Rep ADM) -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
flag (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
s2)
                      ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ (SubExp -> ADM (Exp (Rep ADM)))
-> [SubExp] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp] -> [ADM (Exp (Rep ADM))])
-> [SubExp] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimType -> PrimValue
onePrimValue PrimType
Bool) SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
ne)
                      ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
[ADM (Exp SOACS)] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp SOACS)] -> ADM (Body SOACS))
-> ([VName] -> [ADM (Exp SOACS)]) -> [VName] -> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SubExp -> ADM (Exp SOACS)) -> [SubExp] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp] -> [ADM (Exp SOACS)])
-> ([VName] -> [SubExp]) -> [VName] -> [ADM (Exp SOACS)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PrimValue -> SubExp
Constant (PrimType -> PrimValue
blankPrimValue PrimType
Bool) :) ([SubExp] -> [SubExp])
-> ([VName] -> [SubExp]) -> [VName] -> [SubExp]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> SubExp
Var
                          ([VName] -> ADM (Body SOACS)) -> ADM [VName] -> ADM (Body SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
sas [DimIndex SubExp]
s2
                      )
              )

        (SubExp -> ADM (Exp SOACS)) -> [SubExp] -> ADM [Exp SOACS]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp] -> ADM [Exp SOACS]) -> [SubExp] -> ADM [Exp SOACS]
forall a b. (a -> b) -> a -> b
$ SubExp
f1 SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
r1 [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ [SubExp]
r2

  -- scan (\(f1,v1) (f2,v2) ->
  --   let f = f1 || f2
  --   let v = if f2 then v2 else g v1 v2
  --   in (f,v) ) (false,ne) (zip flags vals)
  [Lambda SOACS]
scan_lams <-
    (Lambda SOACS -> ADM (Lambda SOACS))
-> [Lambda SOACS] -> ADM [Lambda SOACS]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse
      ( \Lambda SOACS
l -> do
          Param Type
f1 <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"f1" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Bool
          Param Type
f2 <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"f2" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Bool
          [Param Type]
ps <- Lambda SOACS -> [Param Type]
Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams (Lambda SOACS -> [Param Type])
-> ADM (Lambda SOACS) -> ADM [Param Type]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
          let ([Param Type]
p1, [Param Type]
p2) = Int -> [Param Type] -> ([Param Type], [Param Type])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ne) [Param Type]
ps

          [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda (Param Type
f1 Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
p1 [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ Param Type
f2 Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
p2) (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
            ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"scan_res" (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
              let f :: ADM (Exp (Rep ADM))
f = BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
LogOr (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f1) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f2)
              ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f2)
                ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM))
ADM (Exp SOACS)
f ADM (Exp SOACS) -> [ADM (Exp SOACS)] -> [ADM (Exp SOACS)]
forall a. a -> [a] -> [a]
: (Param Type -> ADM (Exp SOACS))
-> [Param Type] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Param Type -> ADM (Exp (Rep ADM))
Param Type -> ADM (Exp SOACS)
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam [Param Type]
p2)
                ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
[ADM (Exp SOACS)] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp SOACS)] -> ADM (Body SOACS))
-> ([VName] -> [ADM (Exp SOACS)]) -> [VName] -> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (ADM (Exp (Rep ADM))
ADM (Exp SOACS)
f :) ([ADM (Exp SOACS)] -> [ADM (Exp SOACS)])
-> ([VName] -> [ADM (Exp SOACS)]) -> [VName] -> [ADM (Exp SOACS)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var)
                    ([VName] -> ADM (Body SOACS)) -> ADM [VName] -> ADM (Body SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< String -> [SubExpRes] -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> [SubExpRes] -> m [VName]
bindSubExpRes String
"gres"
                    ([SubExpRes] -> ADM [VName]) -> ADM [SubExpRes] -> ADM [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
l ((Param Type -> ADM (Exp SOACS))
-> [Param Type] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Param Type -> ADM (Exp (Rep ADM))
Param Type -> ADM (Exp SOACS)
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam [Param Type]
ps)
                )
      )
      [Lambda SOACS
lam, Lambda SOACS
lam']

  let ne' :: [SubExp]
ne' = PrimValue -> SubExp
Constant (Bool -> PrimValue
BoolValue Bool
False) SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
ne

  [VName]
scansres <-
    String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"adj_ctrb_scan" (Exp SOACS -> ADM [VName])
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM [VName]) -> SOAC SOACS -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
      SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] ([Scan SOACS] -> Lambda SOACS -> ScremaForm SOACS
forall rep. [Scan rep] -> Lambda rep -> ScremaForm rep
scanomapSOAC ((Lambda SOACS -> Scan SOACS) -> [Lambda SOACS] -> [Scan SOACS]
forall a b. (a -> b) -> [a] -> [b]
map (Lambda SOACS -> [SubExp] -> Scan SOACS
forall rep. Lambda rep -> [SubExp] -> Scan rep
`Scan` [SubExp]
ne') [Lambda SOACS]
scan_lams) Lambda SOACS
g_lam)

  let (VName
_ : [VName]
ls_arr, VName
_ : [VName]
rs_arr_rev) = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ne Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) [VName]
scansres

  -- map (\i -> if withinBounds [(head w, i)] then h_part_bar[i] else blank values) sis
  Param Type
par_i'' <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  let i'' :: VName
i'' = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i''
  Lambda SOACS
map_lam <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
par_i''] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"scan_res"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds ([(SubExp, VName)] -> TPrimExp Bool VName)
-> [(SubExp, VName)] -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ (SubExp, VName) -> [(SubExp, VName)]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
head [SubExp]
w, VName
i''))
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
[ADM (Exp SOACS)] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp SOACS)] -> ADM (Body SOACS))
-> ([VName] -> [ADM (Exp SOACS)]) -> [VName] -> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> ADM (Body SOACS)) -> ADM [VName] -> ADM (Body SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
h_part_bar [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i''])
          ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
              (Type -> ADM (Exp SOACS)) -> [Type] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (\Type
t -> Exp SOACS -> ADM (Exp SOACS)
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp SOACS -> ADM (Exp SOACS)) -> Exp SOACS -> ADM (Exp SOACS)
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> [SubExp] -> Shape
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
tail ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t) (PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue (PrimType -> PrimValue) -> PrimType -> PrimValue
forall a b. (a -> b) -> a -> b
$ Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t)) [Type]
as_type
          )

  [VName]
f_bar <- String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"f_bar" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
sis] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
map_lam

  ([VName]
as_params, Lambda SOACS
f) <- Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], Lambda SOACS)
mkF Lambda SOACS
lam0 [Type]
as_type SubExp
n
  Lambda SOACS
f_adj <- VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops ((VName -> Adj) -> [VName] -> [Adj]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Adj
adjFromVar [VName]
f_bar) [VName]
as_params Lambda SOACS
f

  -- map (\i -> rs_arr_rev[n-i-1]) (iota n)
  Param Type
par_i''' <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  let i''' :: VName
i''' = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i'''
  Lambda SOACS
rev_lam <- [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
par_i'''] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
    SubExp
nmim1 <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"n_i_1" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
n TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i''' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
    [VName] -> [SubExpRes]
varsRes ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
rs_arr_rev [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
nmim1]

  [VName]
rs_arr <- String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"rs_arr" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
rev_lam

  [VName]
sas_bar <-
    String -> [SubExpRes] -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> [SubExpRes] -> m [VName]
bindSubExpRes String
"sas_bar"
      ([SubExpRes] -> ADM [VName]) -> ADM [SubExpRes] -> ADM [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
f_adj ((VName -> ADM (Exp (Rep ADM))) -> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> [ADM (Exp (Rep ADM))])
-> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ [VName]
ls_arr [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
sas [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
rs_arr)

  [VName]
scatter_dst <- (Type -> ADM VName) -> [Type] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (\Type
t -> String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"scatter_dst" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ PrimType -> [SubExp] -> BasicOp
Scratch (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t)) [Type]
as_type
  [VName]
as_bar <- SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter SubExp
n [VName]
scatter_dst VName
siota [VName]
sas_bar

  (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
updateAdj ([VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
tail [VName]
as) [VName]
as_bar
  where
    -- | Build the one-parameter lambda that computes a boolean flag for
    -- position @i@ of the array @sis@:
    --
    -- > map (\i -> if i == 0 then true else sis[i] != sis[i-1]) (iota n)
    --
    -- The flag is true at index 0, and otherwise true exactly when the
    -- element at @i@ differs from its predecessor.
    --
    -- * @par_i@ is the lambda parameter holding the index @i@.
    -- * @sis@ is the array being compared element-wise with its left
    --   neighbour (presumably the sorted bucket indices; confirm against
    --   the caller).
    --
    -- The resulting lambda returns a single value bound under the name
    -- hint @"flag"@.
    mkFlagLam :: LParam SOACS -> VName -> ADM (Lambda SOACS)
    mkFlagLam par_i sis =
      mkLambda [par_i] $
        fmap varsRes . letTupExp "flag" =<< do
          let i = paramName par_i
              -- Bind @sis[ix]@ to a fresh variable.
              sisAt ix =
                letExp "vs" $ BasicOp $ Index sis $ Slice [DimFix $ Var ix]
          eIf
            (toExp (le64 i .==. 0))
            -- First element: always flagged.
            (eBody [eSubExp $ Constant $ onePrimValue Bool])
            -- Otherwise: flagged iff it differs from the previous element.
            -- The two reads are bound explicitly (current first, then
            -- previous) instead of traversing a two-element list and
            -- destructuring the result with a partial list pattern.
            ( eBody . pure $ do
                i_p <- letExp "i_p" =<< toExp (le64 i - 1)
                vs_i <- sisAt i
                vs_p <- sisAt i_p
                toExp $ bNot $ le64 vs_i .==. le64 vs_p
            )