{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-orphans #-}
-- | Tensor class instances for dual numbers. All definitions
-- are generic over whether the dual numbers are built from concrete arrays
-- of floats or from AST terms or anything else (e.g., nested 'ADVal').
module HordeAd.Core.OpsADVal
  ( crevOnADInputs, crevOnParams, cfwdOnParams
  ) where

import Prelude hiding (foldl')

import Data.Maybe (fromMaybe)
import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality (gcastWith, testEquality, (:~:) (Refl))
import Data.Vector.Generic qualified as V
import GHC.TypeLits (sameNat)

import Data.Array.Nested (Replicate, type (++))
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Convert (withShsFromShR, withShsFromShX)
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation qualified as Permutation
import Data.Array.Nested.Ranked.Shape
import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Types (unsafeCoerceRefl)

import HordeAd.Core.CarriersADVal
import HordeAd.Core.CarriersConcrete
import HordeAd.Core.ConvertTensor
import HordeAd.Core.Delta
import HordeAd.Core.DeltaEval
import HordeAd.Core.Ops
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
import HordeAd.Core.Unwind

-- * Non-symbolic (or at least non-sharing) reverse and forward derivative computation

-- The user-written function f can do anything, so the inputs
-- argument has to be duplicable.
-- | Reverse-mode derivative of a function at already-constructed dual-number
-- inputs. Returns the gradient with respect to the inputs paired with the
-- primal value of the function's result. When no seed cotangent is given,
-- the seed is @treplTarget 1@ at the (AD) shape of the result.
crevOnADInputs
  :: forall x z target. (ADReadyNoLet target, ShareTensor target)
  => Maybe (target (ADTensorKind z))
  -> (ADVal target x -> ADVal target z)
  -> FullShapeTK x -> ADVal target x
  -> (target (ADTensorKind x), target z)
-- Break the inline chain to prevent false positives in inspection testing.
-- {-# INLINE crevOnADInputs #-}
crevOnADInputs mSeed userFun inFtk adIns =
  -- Force the user function's result first, so that terms are fully
  -- constructed (and memory can be freed) before evaluation allocates
  -- new memory and new FFI is started.
  let !(D outPrimal outDelta) = userFun adIns in
  -- Only afterwards pick the seed and backpropagate it through the delta.
  let outFtk = ftkDelta outDelta
      seedCt = fromMaybe (treplTarget 1 (adFTK outFtk)) mSeed
      !gradIn = gradientFromDelta inFtk outFtk seedCt outDelta
  in (gradIn, outPrimal)

-- | Reverse-mode derivative of a function at plain (non-dual) parameters.
-- The parameters are paired, without sharing, with fresh input deltas from
-- 'generateDeltaInputs', and the rest is delegated to 'crevOnADInputs'.
crevOnParams
  :: forall x z target. (ADReadyNoLet target, ShareTensor target)
  => Maybe (target (ADTensorKind z))
  -> (ADVal target x -> ADVal target z)
  -> FullShapeTK x -> target x
  -> (target (ADTensorKind x), target z)
{-# INLINE crevOnParams #-}
crevOnParams mSeed userFun inFtk params =
  crevOnADInputs mSeed userFun inFtk adIns
 where
  adIns = dDnotShared params (generateDeltaInputs inFtk)

-- | Forward-mode derivative of a function at already-constructed dual-number
-- inputs, for the given input perturbation. Returns the derivative of the
-- result paired with the result's primal value.
cfwdOnADInputs
  :: forall x z target. (ADReadyNoLet target, ShareTensor target)
  => FullShapeTK x -> ADVal target x
  -> (ADVal target x -> ADVal target z)
  -> target (ADTensorKind x)
  -> (target (ADTensorKind z), target z)
-- Break the inline chain to prevent false positives in inspection testing.
-- {-# INLINE cfwdOnADInputs #-}
cfwdOnADInputs inFtk adIns userFun perturbation =
  -- Force the user function's result before evaluating the derivative.
  let !(D outPrimal outDelta) = userFun adIns in
  let !deriv = derivativeFromDelta @x outDelta (adFTK inFtk) perturbation
  in (deriv, outPrimal)

-- | Forward-mode derivative of a function at plain (non-dual) parameters.
-- The parameters are paired, without sharing, with fresh input deltas from
-- 'generateDeltaInputs', and the rest is delegated to 'cfwdOnADInputs'.
cfwdOnParams
  :: forall x z target. (ADReadyNoLet target, ShareTensor target)
  => FullShapeTK x -> target x
  -> (ADVal target x -> ADVal target z)
  -> target (ADTensorKind x)
  -> (target (ADTensorKind z), target z)
{-# INLINE cfwdOnParams #-}
cfwdOnParams inFtk params userFun perturbation =
  cfwdOnADInputs inFtk adIns userFun perturbation
 where
  adIns = dDnotShared params (generateDeltaInputs inFtk)


-- * Instances

-- | Lift a primal value of the given full shape into a dual number whose
-- delta is 'DeltaZero' at that shape; the value is not shared.
fromPrimalFTK :: FullShapeTK z -> f z -> ADVal f z
fromPrimalFTK ftk pval = dDnotShared pval $ DeltaZero ftk

instance ( ADReadyNoLet target, ShareTensor target
         , ShareTensor (PrimalOf target) )
         => LetTensor (ADVal target) where
  -- Bind a dual number: only the primal component is shared (strictly);
  -- the delta component was already shared and is passed on unchanged.
  ttlet (D aPrimal aDelta) body =
    let !aShared = tshare aPrimal
    in body (dDnotShared aShared aDelta)
  -- Bind a primal value: share it (strictly) and hand it to the body.
  ttletPrimal aPrimal body =
    let !aShared = tshare aPrimal
    in body aShared
  -- The share representation coincides with the dual number itself,
  -- so converting in either direction is the identity.
  toShare = id
  tunshare = id
  -- The singleton is not needed to assemble a dual number.
  tD _stk aPrimal aDelta = dD aPrimal aDelta

instance (ADReadyNoLet target, ShareTensor target)
         => ShareTensor (ADVal target) where
  -- Only the primal component needs sharing; the delta is already shared.
  tshare (D aPrimal aDelta) = dDnotShared (tshare aPrimal) aDelta
  -- Split the primal and the delta of a product separately and re-pair
  -- the halves component-wise (pattern bindings below are lazy).
  tunpair (D aPrimal aDelta) =
    (dDnotShared u1 d1, dDnotShared u2 d2)
   where
    (u1, u2) = tunpair aPrimal
    (d1, d2) = unDeltaPair aDelta

-- Note that this instance doesn't do vectorization. To enable it,
-- use the Ast instance and only then interpret in ADVal.
-- In any case, only the Ast instantiation of this instance
-- is used in the best pipeline, in particular, to satisfy the constraints
-- needed for the interpretation of Ast in ADVal.
-- The ADVal Concrete instantiation is used in other pipelines and tests.
instance ( ADReadyNoLet target, ShareTensor target
         , ShareTensor (PrimalOf target) )
         => BaseTensor (ADVal target) where
  -- Ranked ops
  -- The shape is read off the primal component alone.
  rshape (D aPrimal _aDelta) = rshape aPrimal
  -- Sum away one dimension (rank 1 + n down to n). The runtime width from
  -- 'rwidth' is promoted to an 'SNat' so it can be stored in 'DeltaSum'.
  trsum (D aPrimal aDelta) =
    withSNat (rwidth aPrimal) (\sWidth ->
      dD (trsum aPrimal) (DeltaSum sWidth knownSTK aDelta))
  -- Sum down to a rank-0 result; the delta records this as 'DeltaSum0R'.
  trsum0 (D aPrimal aDelta) = dD (trsum0 aPrimal) $ DeltaSum0R aDelta
  -- Dot product of two same-rank tensors. Both primals are shared because
  -- each is used twice; the delta follows the product rule:
  -- d(u . v) = v . du + u . dv.
  trdot0 (D ue u') (D ve v') =
    -- The bangs below are necessary for GHC 9.2.7 test results to match 9.4.
    let !u = tshare ue in
    let !v = tshare ve
    in dD (trdot0 u v) (dAdd (DeltaDot0R v u') (DeltaDot0R u v'))
  -- These two are manually vectorized to avoid delta blowup when run
  -- via primitive pipelines.
  --
  -- Matrix-vector product: replicate the vector once per matrix row
  -- ('rwidth' of the matrix), multiply elementwise with the matrix,
  -- transpose, and sum over the leading dimension.
  trmatvecmul mat vec =
    let vecPerRow = trreplicate (rwidth mat) vec
    in trsum $ rtr $ vecPerRow * mat
  -- Matrix-matrix product, vectorized: both operands are replicated into
  -- rank-3 tensors, transposed so their axes line up, multiplied
  -- elementwise and summed over the leading (contracted) axis.
  trmatmul2 m1 m2 = case rshape m2 of
    _ :$: nCols :$: ZSR ->
      let lhs3 = trtranspose [2, 1, 0] (trreplicate nCols m1)
          rhs3 = trtranspose [1, 0] (trreplicate (rwidth m1) m2)
      in trsum (lhs3 * rhs3)
  -- Replicate @k@ times, adding one dimension (rank n to 1 + n). The
  -- runtime count is promoted to an 'SNat' for 'DeltaReplicate'.
  trreplicate k (D aPrimal aDelta) =
    withSNat k (\sCount ->
      dD (trreplicate k aPrimal) (DeltaReplicate sCount knownSTK aDelta))
  -- Indexing: each index is reduced to its primal part and shared once,
  -- then the same shared index is used for both the primal and the delta.
  trindex (D aPrimal aDelta) ixs =
    let !ixPrimal = fmap (tshare . tprimalPart) ixs
    in dD (trindex aPrimal ixPrimal) (DeltaIndexR SNat aDelta ixPrimal)
  -- Scatter. The user index function is run on primal indexes lifted via
  -- 'tfromPrimal', and only the primal parts of its results are kept, so
  -- the same integer-only function serves both the primal and the delta.
  trscatter sh (D aPrimal aDelta) f =
    dD (trscatter sh aPrimal g) (DeltaScatterR SNat SNat SNat sh aDelta g)
   where
    g = fmap tprimalPart . f . fmap (tfromPrimal STKScalar)
  trgather :: forall (m :: Nat) (n :: Nat) (p :: Nat) (x :: TK).
(KnownNat m, KnownNat n, KnownNat p, KnownSTK x) =>
IShR (m + n)
-> ADVal target (TKR2 (p + n) x)
-> (IxROf (ADVal target) m -> IxROf (ADVal target) p)
-> ADVal target (TKR2 (m + n) x)
trgather IShR (m + n)
sh (D target (TKR2 (p + n) x)
u Delta target (TKR2 (p + n) x)
u') IxROf (ADVal target) m -> IxROf (ADVal target) p
f =
    let g :: IxR m (PrimalOf target (TKScalar Int64))
-> IxR p (PrimalOf target (TKScalar Int64))
g IxR m (PrimalOf target (TKScalar Int64))
x = target (TKScalar Int64) -> PrimalOf target (TKScalar Int64)
forall (y :: TK). target y -> PrimalOf target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
target y -> PrimalOf target y
tprimalPart (target (TKScalar Int64) -> PrimalOf target (TKScalar Int64))
-> IxR p (target (TKScalar Int64))
-> IxR p (PrimalOf target (TKScalar Int64))
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> IxROf (ADVal target) m -> IxROf (ADVal target) p
f (SingletonTK (TKScalar Int64)
-> PrimalOf target (TKScalar Int64) -> target (TKScalar Int64)
forall (y :: TK). SingletonTK y -> PrimalOf target y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> PrimalOf target y -> target y
tfromPrimal SingletonTK (TKScalar Int64)
forall r. GoodScalar r => SingletonTK (TKScalar r)
STKScalar (PrimalOf target (TKScalar Int64) -> target (TKScalar Int64))
-> IxR m (PrimalOf target (TKScalar Int64))
-> IxR m (target (TKScalar Int64))
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> IxR m (PrimalOf target (TKScalar Int64))
x)
    in target (TKR2 (m + n) x)
-> Delta target (TKR2 (m + n) x) -> ADVal target (TKR2 (m + n) x)
forall (f :: Target) (z :: TK). f z -> Delta f z -> ADVal f z
dD (IShR (m + n)
-> target (TKR2 (p + n) x)
-> (IxR m (PrimalOf target (TKScalar Int64))
    -> IxR p (PrimalOf target (TKScalar Int64)))
-> target (TKR2 (m + n) x)
forall (m :: Nat) (n :: Nat) (p :: Nat) (x :: TK).
(KnownNat m, KnownNat n, KnownNat p, KnownSTK x) =>
IShR (m + n)
-> target (TKR2 (p + n) x)
-> (IxROf target m -> IxROf target p)
-> target (TKR2 (m + n) x)
forall (target :: Target) (m :: Nat) (n :: Nat) (p :: Nat)
       (x :: TK).
(BaseTensor target, KnownNat m, KnownNat n, KnownNat p,
 KnownSTK x) =>
IShR (m + n)
-> target (TKR2 (p + n) x)
-> (IxROf target m -> IxROf target p)
-> target (TKR2 (m + n) x)
trgather IShR (m + n)
sh target (TKR2 (p + n) x)
u IxR m (PrimalOf target (TKScalar Int64))
-> IxR p (PrimalOf target (TKScalar Int64))
g) (SNat m
-> SNat n
-> SNat p
-> IShR (m + n)
-> Delta target (TKR2 (p + n) x)
-> (IxR m (PrimalOf target (TKScalar Int64))
    -> IxR p (PrimalOf target (TKScalar Int64)))
-> Delta target (TKR2 (m + n) x)
forall (m :: Nat) (n :: Nat) (p :: Nat) (r :: TK) (a :: Target).
SNat m
-> SNat n
-> SNat p
-> IShR (m + n)
-> Delta a (TKR2 (p + n) r)
-> (IxROf a m -> IxROf a p)
-> Delta a (TKR2 (m + n) r)
DeltaGatherR SNat m
forall (n :: Nat). KnownNat n => SNat n
SNat SNat n
forall (n :: Nat). KnownNat n => SNat n
SNat SNat p
forall (n :: Nat). KnownNat n => SNat n
SNat IShR (m + n)
sh Delta target (TKR2 (p + n) x)
u' IxR m (PrimalOf target (TKScalar Int64))
-> IxR p (PrimalOf target (TKScalar Int64))
g)
      -- Note how f is not interpreted as a function on dual numbers
      -- but just on integers and so no cotangents for results of application
      -- of f have to be computed and stored in cotangent maps later on.
      -- Note also how g is duplicated and this leads to loss of sharing
      -- of indexes in AST instances.
  -- The result depends only on the given concrete array, so it is lifted
  -- with 'fromPrimalFTK'; the full shape is read off the array itself.
  trconcrete arr =
    fromPrimalFTK (FTKR (Nested.rshape arr) FTKScalar) (trconcrete arr)
  -- The input's delta is discarded; the floored primal is lifted with
  -- 'fromPrimalFTK', using the shape of the floored result.
  trfloor (D prim _) = fromPrimalFTK (FTKR (rshape floored) FTKScalar) floored
    where
      floored = trfloor prim
  -- The delta of the integral input is discarded; only the converted
  -- primal is lifted, with its shape taken from the converted result.
  trfromIntegral (D prim _) =
    fromPrimalFTK (FTKR (rshape converted) FTKScalar) converted
    where
      converted = trfromIntegral prim
  -- Casting between real-fractional scalar types keeps the delta, which is
  -- wrapped in 'DeltaCastR' alongside the cast primal.
  trcast (D prim dual) = dD (trcast prim) (DeltaCastR dual)
  -- Index computation: the input's delta is discarded and the resulting
  -- indices are lifted as a primal-only value of their own shape.
  trminIndex (D prim _) =
    fromPrimalFTK (FTKR (rshape indices) FTKScalar) indices
    where
      indices = trminIndex prim
  -- Index computation: the input's delta is discarded and the resulting
  -- indices are lifted as a primal-only value of their own shape.
  trmaxIndex (D prim _) =
    fromPrimalFTK (FTKR (rshape indices) FTKScalar) indices
    where
      indices = trmaxIndex prim
  -- The iota vector takes no dual-number input; it is lifted with
  -- 'fromPrimalFTK' as a rank-1 value of length @n@.
  triota n = fromPrimalFTK (FTKR (n :$: ZSR) FTKScalar) (triota n)
  -- Append the primals, and record the matching append of the deltas.
  trappend (D pu du) (D pv dv) = dD (trappend pu pv) (DeltaAppendR du dv)
  -- Slice the primal and apply the same slice arguments @i@ and @n@ to
  -- the delta.
  trslice i n (D prim dual) = dD (trslice i n prim) (DeltaSliceR i n dual)
  -- Reverse the primal, and record the reversal in the delta.
  trreverse (D prim dual) = dD (trreverse prim) (DeltaReverseR dual)
  -- Transpose the primal by @perm@, and record the same permutation
  -- in the delta.
  trtranspose perm (D prim dual) =
    dD (trtranspose perm prim) (DeltaTransposeR perm dual)
  -- Reshape the primal to @sh@, and record the same target shape in the
  -- delta.
  trreshape sh (D prim dual) = dD (trreshape sh prim) (DeltaReshapeR sh dual)
  -- Build a rank-(1 + n) tensor by applying @f@ to every index in
  -- @0 .. k - 1@ and stacking the results with 'trfromVector'.
  -- When there is no index to apply @f@ to (@k <= 0@), the element shape is
  -- unknown.  Only the @n == 0@ case is handled then, yielding an empty
  -- rank-1 concrete array; any other rank is an error.
  trbuild1 @n @x k f
    | k > 0 =
        -- hope this fuses
        trfromVector $ V.fromList
        $ map (f . fromInteger) [0 .. fromIntegral k - 1]
    | otherwise =
        case sameNat (Proxy @n) (Proxy @0) of
          Just Refl | Dict <- eltDictRep (knownSTK @x) ->
            let emptyArr = Nested.remptyArray
            in tconcrete (tftkG knownSTK emptyArr) (Concrete emptyArr)
          Nothing -> error "rbuild1: shape ambiguity"

  -- Shaped ops
  -- The shape is read from the primal component alone; the delta is unused.
  sshape (D prim _) = sshape prim
  -- Sum away the outermost dimension.  The delta term records the summed
  -- dimension ('SNat') and the result's tensor kind ('knownSTK').
  tssum (D prim dual) = dD (tssum prim) (DeltaSum SNat knownSTK dual)
  -- Reduce all dimensions to a rank-0 result, mirrored by 'DeltaSum0S'.
  tssum0 (D prim dual) = dD (tssum0 prim) (DeltaSum0S dual)
  -- Dot product of two dual numbers.  Each primal is used twice (in the
  -- primal result and in the other operand's delta term), hence 'tshare'.
  -- The delta follows the product rule: v * du + u * dv.
  tsdot0 (D pu du) (D pv dv) =
    -- The bangs below are necessary for GHC 9.2.7 test results to match 9.4;
    -- the nested lets force the first share before the second.
    let !u = tshare pu in
    let !v = tshare pv
    in dD (tsdot0 u v) (dAdd (DeltaDot0S v du) (DeltaDot0S u dv))
  -- These two are manually vectorized to avoid delta blowup when run
  -- via primitive pipelines.
  -- Replicate @v@ (shape [n]) once per row to shape [m, n] and multiply
  -- it elementwise with @m@.  The product is transposed to [n, m], and the
  -- leading n axis is summed away, giving [m].
  tsmatvecmul m v = tssum $ str $ tsreplicate SNat knownShS v * m
  -- Both factors are brought to the common shape [n, m, p]:
  --   * m1 ([m, n]) is replicated p times to [p, m, n], then transposed
  --     with [2, 1, 0];
  --   * m2 ([n, p]) is replicated m times to [m, n, p], then transposed
  --     with [1, 0].
  -- Summing their elementwise product over the leading n axis yields [m, p].
  tsmatmul2 m1 m2 =
    let lhs = tstranspose (Permutation.makePerm @'[2, 1, 0])
                          (tsreplicate SNat knownShS m1)
        rhs = tstranspose (Permutation.makePerm @'[1, 0])
                          (tsreplicate SNat knownShS m2)
    in tssum (lhs * rhs)
  -- Index expressions are projected to their primal parts and shared,
  -- since the same index is used by the primal 'tsindex' and by the
  -- 'DeltaIndexS' term.  The bang forces the index before it is used.
  tsindex (D prim dual) i =
    let !ix = fmap (tshare . tprimalPart) i
    in dD (tsindex prim ix) (DeltaIndexS knownShS dual ix)
  -- Scatter on dual numbers. The user-supplied index function @f@ works on
  -- 'ADVal' integers. It is adapted to primal integers by injecting every
  -- input index with @tfromPrimal STKScalar@ and projecting every output
  -- index back with 'tprimalPart'. The same adapted function is used for the
  -- primal scatter and for the 'DeltaScatterS' node.
  tsscatter @shm @shn @shp (D u u') f =
    dD (tsscatter @_ @shm @shn @shp u primalIx)
       (DeltaScatterS @shm @shn @shp knownShS knownShS knownShS u' primalIx)
    where
      primalIx ix = tprimalPart <$> f (tfromPrimal STKScalar <$> ix)
  -- Gather on dual numbers. The user-supplied index function @f@ works on
  -- 'ADVal' integers. It is adapted to primal integers by injecting every
  -- input index with @tfromPrimal STKScalar@ and projecting every output
  -- index back with 'tprimalPart'. The same adapted function is used for the
  -- primal gather and for the 'DeltaGatherS' node.
  tsgather @shm @shn @shp (D u u') f =
    dD (tsgather @_ @shm @shn @shp u primalIx)
       (DeltaGatherS @shm @shn @shp knownShS knownShS knownShS u' primalIx)
    where
      primalIx ix = tprimalPart <$> f (tfromPrimal STKScalar <$> ix)
  -- Lifts a concrete shaped array with 'fromPrimalFTK'. The full shape is
  -- rebuilt from the array's own shape ('Nested.sshape') with a scalar
  -- element kind.
  tsconcrete a =
    fromPrimalFTK (FTKS (Nested.sshape a) FTKScalar) (tsconcrete a)
  -- Floor of a dual number. The input's delta is discarded (pattern
  -- @D u _@). Only the floored primal is lifted, via 'fromPrimalFTK', with
  -- its shape read back from the result.
  tsfloor (D u _) = fromPrimalFTK (FTKS (sshape floored) FTKScalar) floored
    where
      floored = tsfloor u
  -- Integral-to-scalar conversion. The input's delta is discarded (pattern
  -- @D u _@). Only the converted primal is lifted, via 'fromPrimalFTK', with
  -- its shape read back from the result.
  tsfromIntegral (D u _) =
    fromPrimalFTK (FTKS (sshape converted) FTKScalar) converted
    where
      converted = tsfromIntegral u
  -- Cast between real-fractional scalar types. Both components are
  -- transformed: the primal with 'tscast' and the delta with 'DeltaCastS'.
  tscast (D prim delta) = dD (tscast prim) (DeltaCastS delta)
  -- Index of the minimum along the outer dimension. The input's delta is
  -- discarded (pattern @D u _@). Only the primal result is lifted, via
  -- 'fromPrimalFTK', with its shape read back from that result.
  tsminIndex (D u _) = fromPrimalFTK (FTKS (sshape idxs) FTKScalar) idxs
    where
      idxs = tsminIndex u
  -- Index of the maximum along the outer dimension. The input's delta is
  -- discarded (pattern @D u _@). Only the primal result is lifted, via
  -- 'fromPrimalFTK', with its shape read back from that result.
  tsmaxIndex (D u _) = fromPrimalFTK (FTKS (sshape idxs) FTKScalar) idxs
    where
      idxs = tsmaxIndex u
  -- Lifts the primal 'tsiota' with 'fromPrimalFTK'. Its full shape is the
  -- statically known rank-1 shape @[n]@ with a scalar element kind.
  tsiota =
    let ftk = FTKS (SNat :$$ ZSS) FTKScalar
    in fromPrimalFTK ftk tsiota
  -- Concatenation along the outer dimension. The primals are appended with
  -- 'tsappend' and the deltas are combined into a 'DeltaAppendS' node.
  tsappend (D prim1 delta1) (D prim2 delta2) =
    dD (tsappend prim1 prim2) (DeltaAppendS delta1 delta2)
  -- Slice of length @n@ along the outer dimension, after skipping @i@
  -- elements and dropping @k@ trailing ones. The same three sizes are passed
  -- to the primal 'tsslice' and to the 'DeltaSliceS' node.
  tsslice i n k (D prim delta) =
    dD (tsslice i n k prim) (DeltaSliceS i n k delta)
  -- Reversal along the outer dimension, applied to the primal with
  -- 'tsreverse' and recorded in the delta as 'DeltaReverseS'.
  tsreverse (D prim delta) = dD (tsreverse prim) (DeltaReverseS delta)
  -- Builds a tensor whose outer dimension has size @k@. It applies @f@ to
  -- each index @0 .. k-1@ and stacks the results with 'tsfromVector'.
  -- If the index list is empty (@k@ is 0), @f@ is never called and an empty
  -- concrete array of inner shape @sh@ is returned instead. @k ~ 0@ is
  -- asserted with 'unsafeCoerceRefl'; that is justified only because the
  -- list @[0 .. k-1]@ was just observed to be empty.
  tsbuild1 @k @sh @x f | Dict <- eltDictRep (knownSTK @x) =
    case [0 .. valueOf @k - 1] :: [Integer] of
      [] ->
        let emptyArr = Nested.semptyArray @(RepConcrete x) (knownShS @sh)
        in gcastWith (unsafeCoerceRefl :: k :~: 0) $
             tconcrete (tftkG knownSTK emptyArr) (Concrete emptyArr)
      ixs -> tsfromVector $ V.fromList $ map (f . fromInteger) ixs
        -- hope this fuses

  -- Mixed ops
  -- The shape is read from the primal alone; the delta is ignored.
  xshape (D prim _) = xshape prim
  -- Sum over the outermost dimension, whose static size @n@ comes from the
  -- @'Just n@ head of the shape. The primal is summed with 'txsum' and the
  -- delta is wrapped in a 'DeltaSum' node built from @SNat@ and 'knownSTK'.
  txsum (D prim delta) = dD (txsum prim) (DeltaSum SNat knownSTK delta)
  -- Full reduction to a rank-0 mixed tensor: 'txsum0' on the primal and a
  -- 'DeltaSum0X' node on the delta.
  txsum0 (D prim delta) = dD (txsum0 prim) (DeltaSum0X delta)
  -- Dot product of two dual tensors, differentiated by the product rule:
  -- d(u . v) = v . du + u . dv. Each primal is shared with 'tshare' because
  -- it occurs twice: once in the primal result and once in a delta term.
  -- The two nested banged lets are kept as written so that the forcing order
  -- of the two 'tshare' calls is unchanged.
  txdot0 (D uRaw du) (D vRaw dv) =
    -- The bangs below are necessary for GHC 9.2.7 test results to match 9.4.
    let !u = tshare uRaw in
    let !v = tshare vRaw
    in dD (txdot0 u v) (dAdd (DeltaDot0X v du) (DeltaDot0X u dv))
  -- These two are manually vectorized to avoid delta blowup when run
  -- via primitive pipelines.
  txmatvecmul :: forall (mm :: Maybe Nat) (mn :: Maybe Nat) r.
(GoodScalar r, ConvertTensor (ADVal target)) =>
SMayNat @Nat Int SNat mm
-> SMayNat @Nat Int SNat mn
-> ADVal
     target
     (TKX
        ((':) @(Maybe Nat) mm ((':) @(Maybe Nat) mn ('[] @(Maybe Nat)))) r)
-> ADVal target (TKX ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))) r)
-> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r)
txmatvecmul SMayNat @Nat Int SNat mm
mm SMayNat @Nat Int SNat mn
mn ADVal
  target
  (TKX
     ((':) @(Maybe Nat) mm ((':) @(Maybe Nat) mn ('[] @(Maybe Nat)))) r)
m ADVal target (TKX ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))) r)
v =
    StaticShX ((':) @(Maybe Nat) mn ('[] @(Maybe Nat)))
-> (KnownShX ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))) =>
    ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r))
-> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r)
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX (ShX ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))) Int
-> StaticShX ((':) @(Maybe Nat) mn ('[] @(Maybe Nat)))
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX (ShX ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))) Int
 -> StaticShX ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))))
-> ShX ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))) Int
-> StaticShX ((':) @(Maybe Nat) mn ('[] @(Maybe Nat)))
forall a b. (a -> b) -> a -> b
$ SMayNat @Nat Int SNat mn
mn SMayNat @Nat Int SNat mn
-> ShX ('[] @(Maybe Nat)) Int
-> ShX ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))) Int
forall {sh1 :: [Maybe Nat]} {i} (n :: Maybe Nat)
       (sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat i SNat n -> ShX sh i -> ShX sh1 i
:$% ShX ('[] @(Maybe Nat)) Int
forall (sh :: [Maybe Nat]) i.
((sh :: [Maybe Nat]) ~ ('[] @(Maybe Nat) :: [Maybe Nat])) =>
ShX sh i
ZSX) ((KnownShX ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))) =>
  ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r))
 -> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r))
-> (KnownShX ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))) =>
    ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r))
-> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r)
forall a b. (a -> b) -> a -> b
$
    StaticShX
  ((':) @(Maybe Nat) mm ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))))
-> (KnownShX
      ((':) @(Maybe Nat) mm ((':) @(Maybe Nat) mn ('[] @(Maybe Nat)))) =>
    ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r))
-> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r)
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX (ShX
  ((':) @(Maybe Nat) mm ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))))
  Int
-> StaticShX
     ((':) @(Maybe Nat) mm ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))))
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX (ShX
   ((':) @(Maybe Nat) mm ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))))
   Int
 -> StaticShX
      ((':) @(Maybe Nat) mm ((':) @(Maybe Nat) mn ('[] @(Maybe Nat)))))
-> ShX
     ((':) @(Maybe Nat) mm ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))))
     Int
-> StaticShX
     ((':) @(Maybe Nat) mm ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))))
forall a b. (a -> b) -> a -> b
$ SMayNat @Nat Int SNat mm
mm SMayNat @Nat Int SNat mm
-> ShX ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))) Int
-> ShX
     ((':) @(Maybe Nat) mm ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))))
     Int
forall {sh1 :: [Maybe Nat]} {i} (n :: Maybe Nat)
       (sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat i SNat n -> ShX sh i -> ShX sh1 i
:$% SMayNat @Nat Int SNat mn
mn SMayNat @Nat Int SNat mn
-> ShX ('[] @(Maybe Nat)) Int
-> ShX ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))) Int
forall {sh1 :: [Maybe Nat]} {i} (n :: Maybe Nat)
       (sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat i SNat n -> ShX sh i -> ShX sh1 i
:$% ShX ('[] @(Maybe Nat)) Int
forall (sh :: [Maybe Nat]) i.
((sh :: [Maybe Nat]) ~ ('[] @(Maybe Nat) :: [Maybe Nat])) =>
ShX sh i
ZSX) ((KnownShX
    ((':) @(Maybe Nat) mm ((':) @(Maybe Nat) mn ('[] @(Maybe Nat)))) =>
  ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r))
 -> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r))
-> (KnownShX
      ((':) @(Maybe Nat) mm ((':) @(Maybe Nat) mn ('[] @(Maybe Nat)))) =>
    ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r))
-> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r)
forall a b. (a -> b) -> a -> b
$
    Int
-> (forall (n :: Nat).
    KnownNat n =>
    SNat n
    -> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r))
-> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r)
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat (SMayNat @Nat Int SNat mm -> Int
forall (n :: Maybe Nat). SMayNat @Nat Int SNat n -> Int
fromSMayNat' SMayNat @Nat Int SNat mm
mm) ((forall (n :: Nat).
  KnownNat n =>
  SNat n
  -> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r))
 -> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r))
-> (forall (n :: Nat).
    KnownNat n =>
    SNat n
    -> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r))
-> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r)
forall a b. (a -> b) -> a -> b
$ \(SNat @m) ->
    Int
-> (forall (n :: Nat).
    KnownNat n =>
    SNat n
    -> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r))
-> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r)
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat (SMayNat @Nat Int SNat mn -> Int
forall (n :: Maybe Nat). SMayNat @Nat Int SNat n -> Int
fromSMayNat' SMayNat @Nat Int SNat mn
mn) ((forall (n :: Nat).
  KnownNat n =>
  SNat n
  -> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r))
 -> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r))
-> (forall (n :: Nat).
    KnownNat n =>
    SNat n
    -> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r))
-> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r)
forall a b. (a -> b) -> a -> b
$ \(SNat @n) ->
      StaticShX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat)))
-> ADVal
     target
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))) (TKScalar r))
-> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r)
forall (x :: TK) (sh :: [Maybe Nat]) (sh2 :: [Maybe Nat]).
(KnownSTK x, KnownShX sh,
 (Rank @(Maybe Nat) sh :: Nat) ~ (Rank @(Maybe Nat) sh2 :: Nat),
 ConvertTensor (ADVal target)) =>
StaticShX sh2
-> ADVal target (TKX2 sh x) -> ADVal target (TKX2 sh2 x)
forall (target :: Target) (x :: TK) (sh :: [Maybe Nat])
       (sh2 :: [Maybe Nat]).
(BaseTensor target, KnownSTK x, KnownShX sh,
 (Rank @(Maybe Nat) sh :: Nat) ~ (Rank @(Maybe Nat) sh2 :: Nat),
 ConvertTensor target) =>
StaticShX sh2 -> target (TKX2 sh x) -> target (TKX2 sh2 x)
xmcast (ShX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) Int
-> StaticShX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat)))
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX (SMayNat @Nat Int SNat mm
mm SMayNat @Nat Int SNat mm
-> ShX ('[] @(Maybe Nat)) Int
-> ShX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) Int
forall {sh1 :: [Maybe Nat]} {i} (n :: Maybe Nat)
       (sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat i SNat n -> ShX sh i -> ShX sh1 i
:$% ShX ('[] @(Maybe Nat)) Int
forall (sh :: [Maybe Nat]) i.
((sh :: [Maybe Nat]) ~ ('[] @(Maybe Nat) :: [Maybe Nat])) =>
ShX sh i
ZSX))
      (ADVal
   target
   (TKX2
      ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))) (TKScalar r))
 -> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r))
-> ADVal
     target
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))) (TKScalar r))
-> ADVal target (TKX ((':) @(Maybe Nat) mm ('[] @(Maybe Nat))) r)
forall a b. (a -> b) -> a -> b
$ ADVal
  target
  (TKX2
     ((':)
        @(Maybe Nat)
        ('Just @Nat n)
        ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))))
     (TKScalar r))
-> ADVal
     target
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))) (TKScalar r))
forall (n :: Nat) (sh :: [Maybe Nat]) (x :: TK).
(KnownNat n, KnownShX sh, KnownSTK x) =>
ADVal target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
-> ADVal target (TKX2 sh x)
forall (target :: Target) (n :: Nat) (sh :: [Maybe Nat]) (x :: TK).
(BaseTensor target, KnownNat n, KnownShX sh, KnownSTK x) =>
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
-> target (TKX2 sh x)
txsum (ADVal
  target
  (TKX2
     ((':)
        @(Maybe Nat)
        ('Just @Nat n)
        ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))))
     (TKScalar r))
-> ADVal
     target
     (TKX2
        ((':)
           @(Maybe Nat)
           ('Just @Nat n)
           ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))))
        (TKScalar r))
forall (n :: Nat) (m :: Nat) (sh :: [Maybe Nat]) (x :: TK)
       (target :: Target).
(KnownSTK x, BaseTensor target) =>
target
  (TKX2
     ((':)
        @(Maybe Nat) ('Just @Nat n) ((':) @(Maybe Nat) ('Just @Nat m) sh))
     x)
-> target
     (TKX2
        ((':)
           @(Maybe Nat) ('Just @Nat m) ((':) @(Maybe Nat) ('Just @Nat n) sh))
        x)
xtr (SNat n
-> StaticShX ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat)))
-> ADVal
     target
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))) (TKScalar r))
-> ADVal
     target
     (TKX2
        ((':)
           @(Maybe Nat)
           ('Just @Nat n)
           ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))))
        (TKScalar r))
forall (sh :: [Maybe Nat]) (k :: Nat) (x :: TK).
KnownSTK x =>
SNat k
-> StaticShX sh
-> ADVal target (TKX2 sh x)
-> ADVal target (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) x)
forall (target :: Target) (sh :: [Maybe Nat]) (k :: Nat) (x :: TK).
(BaseTensor target, KnownSTK x) =>
SNat k
-> StaticShX sh
-> target (TKX2 sh x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) x)
txreplicate (forall (n :: Nat). KnownNat n => SNat n
SNat @m) StaticShX ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat)))
forall (sh :: [Maybe Nat]). KnownShX sh => StaticShX sh
knownShX
                      (StaticShX ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat)))
-> ADVal target (TKX ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))) r)
-> ADVal
     target
     (TKX2
        ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))) (TKScalar r))
forall (x :: TK) (sh :: [Maybe Nat]) (sh2 :: [Maybe Nat]).
(KnownSTK x, KnownShX sh,
 (Rank @(Maybe Nat) sh :: Nat) ~ (Rank @(Maybe Nat) sh2 :: Nat),
 ConvertTensor (ADVal target)) =>
StaticShX sh2
-> ADVal target (TKX2 sh x) -> ADVal target (TKX2 sh2 x)
forall (target :: Target) (x :: TK) (sh :: [Maybe Nat])
       (sh2 :: [Maybe Nat]).
(BaseTensor target, KnownSTK x, KnownShX sh,
 (Rank @(Maybe Nat) sh :: Nat) ~ (Rank @(Maybe Nat) sh2 :: Nat),
 ConvertTensor target) =>
StaticShX sh2 -> target (TKX2 sh x) -> target (TKX2 sh2 x)
xmcast (ShX
  ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat)))
  (ZonkAny @Type 0)
-> StaticShX ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat)))
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX (SNat n -> SMayNat @Nat (ZonkAny @Type 0) SNat ('Just @Nat n)
forall {k} (f :: k -> Type) (n1 :: k) i.
f n1 -> SMayNat @k i f ('Just @k n1)
Nested.SKnown (forall (n :: Nat). KnownNat n => SNat n
SNat @n)
                                             SMayNat @Nat (ZonkAny @Type 0) SNat ('Just @Nat n)
-> ShX ('[] @(Maybe Nat)) (ZonkAny @Type 0)
-> ShX
     ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat)))
     (ZonkAny @Type 0)
forall {sh1 :: [Maybe Nat]} {i} (n :: Maybe Nat)
       (sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat i SNat n -> ShX sh i -> ShX sh1 i
:$% ShX ('[] @(Maybe Nat)) (ZonkAny @Type 0)
forall (sh :: [Maybe Nat]) i.
((sh :: [Maybe Nat]) ~ ('[] @(Maybe Nat) :: [Maybe Nat])) =>
ShX sh i
ZSX)) ADVal target (TKX ((':) @(Maybe Nat) mn ('[] @(Maybe Nat))) r)
v)
                    ADVal
  target
  (TKX2
     ((':)
        @(Maybe Nat)
        ('Just @Nat n)
        ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))))
     (TKScalar r))
-> ADVal
     target
     (TKX2
        ((':)
           @(Maybe Nat)
           ('Just @Nat n)
           ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))))
        (TKScalar r))
-> ADVal
     target
     (TKX2
        ((':)
           @(Maybe Nat)
           ('Just @Nat n)
           ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))))
        (TKScalar r))
forall a. Num a => a -> a -> a
* StaticShX
  ((':)
     @(Maybe Nat)
     ('Just @Nat n)
     ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))))
-> ADVal
     target
     (TKX
        ((':) @(Maybe Nat) mm ((':) @(Maybe Nat) mn ('[] @(Maybe Nat)))) r)
-> ADVal
     target
     (TKX2
        ((':)
           @(Maybe Nat)
           ('Just @Nat n)
           ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))))
        (TKScalar r))
forall (x :: TK) (sh :: [Maybe Nat]) (sh2 :: [Maybe Nat]).
(KnownSTK x, KnownShX sh,
 (Rank @(Maybe Nat) sh :: Nat) ~ (Rank @(Maybe Nat) sh2 :: Nat),
 ConvertTensor (ADVal target)) =>
StaticShX sh2
-> ADVal target (TKX2 sh x) -> ADVal target (TKX2 sh2 x)
forall (target :: Target) (x :: TK) (sh :: [Maybe Nat])
       (sh2 :: [Maybe Nat]).
(BaseTensor target, KnownSTK x, KnownShX sh,
 (Rank @(Maybe Nat) sh :: Nat) ~ (Rank @(Maybe Nat) sh2 :: Nat),
 ConvertTensor target) =>
StaticShX sh2 -> target (TKX2 sh x) -> target (TKX2 sh2 x)
xmcast (ShX
  ((':)
     @(Maybe Nat)
     ('Just @Nat n)
     ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))))
  (ZonkAny @Type 1)
-> StaticShX
     ((':)
        @(Maybe Nat)
        ('Just @Nat n)
        ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))))
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX (SNat n -> SMayNat @Nat (ZonkAny @Type 1) SNat ('Just @Nat n)
forall {k} (f :: k -> Type) (n1 :: k) i.
f n1 -> SMayNat @k i f ('Just @k n1)
Nested.SKnown (forall (n :: Nat). KnownNat n => SNat n
SNat @m)
                                            SMayNat @Nat (ZonkAny @Type 1) SNat ('Just @Nat n)
-> ShX
     ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat)))
     (ZonkAny @Type 1)
-> ShX
     ((':)
        @(Maybe Nat)
        ('Just @Nat n)
        ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat))))
     (ZonkAny @Type 1)
forall {sh1 :: [Maybe Nat]} {i} (n :: Maybe Nat)
       (sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat i SNat n -> ShX sh i -> ShX sh1 i
:$% SNat n -> SMayNat @Nat (ZonkAny @Type 1) SNat ('Just @Nat n)
forall {k} (f :: k -> Type) (n1 :: k) i.
f n1 -> SMayNat @k i f ('Just @k n1)
Nested.SKnown (forall (n :: Nat). KnownNat n => SNat n
SNat @n)
                                            SMayNat @Nat (ZonkAny @Type 1) SNat ('Just @Nat n)
-> ShX ('[] @(Maybe Nat)) (ZonkAny @Type 1)
-> ShX
     ((':) @(Maybe Nat) ('Just @Nat n) ('[] @(Maybe Nat)))
     (ZonkAny @Type 1)
forall {sh1 :: [Maybe Nat]} {i} (n :: Maybe Nat)
       (sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat i SNat n -> ShX sh i -> ShX sh1 i
:$% ShX ('[] @(Maybe Nat)) (ZonkAny @Type 1)
forall (sh :: [Maybe Nat]) i.
((sh :: [Maybe Nat]) ~ ('[] @(Maybe Nat) :: [Maybe Nat])) =>
ShX sh i
ZSX)) ADVal
  target
  (TKX
     ((':) @(Maybe Nat) mm ((':) @(Maybe Nat) mn ('[] @(Maybe Nat)))) r)
m))
  -- Matrix product of an [m, n] and an [n, p] matrix, giving [m, p],
  -- written with tensor operations only. Both operands are replicated
  -- and transposed into a common [n, m, p] layout, multiplied
  -- elementwise, and the leading n axis is then summed away.
  txmatmul2 m1 m2 =
    let lhs = txtranspose (Permutation.makePerm @'[2, 1, 0])
                          (txreplicate SNat knownShX m1)
        rhs = txtranspose (Permutation.makePerm @'[1, 0])
                          (txreplicate SNat knownShX m2)
    in txsum (lhs * rhs)
  -- Replication is applied to the primal part directly. The delta part
  -- gets a 'DeltaReplicate' node that also records the singleton tensor
  -- kind of the replicated value.
  txreplicate k shx (D uPrim uDelta) =
    let elemStk = STKX shx knownSTK
    in dD (txreplicate k shx uPrim) (DeltaReplicate k elemStk uDelta)
  -- Index components are reduced to their primal parts and passed
  -- through 'tshare'. The resulting index is forced to WHNF and then used
  -- both for indexing the primal and in the 'DeltaIndexX' node.
  txindex (D uPrim uDelta) ixAD =
    let !ixPrimal = fmap (tshare . tprimalPart) ixAD
    in dD (txindex uPrim ixPrimal) (DeltaIndexX knownShX uDelta ixPrimal)
  -- The user-supplied index map 'f' works on ADVal indices. It is adapted
  -- to primal indices: embed them with 'tfromPrimal', apply 'f', and
  -- project the result back with 'tprimalPart'. The same adapted map
  -- drives both the primal scatter and the 'DeltaScatterX' node.
  txscatter @shm @shn @shp sh (D uPrim uDelta) f =
    let gPrimal ix = fmap tprimalPart (f (fmap (tfromPrimal STKScalar) ix))
        resPrimal = txscatter @_ @shm @shn @shp sh uPrim gPrimal
        resDelta = DeltaScatterX @shm @shn @shp knownShX knownShX knownShX
                                 sh uDelta gPrimal
    in dD resPrimal resDelta
  -- The user-supplied index map 'f' works on ADVal indices. It is adapted
  -- to primal indices: embed them with 'tfromPrimal', apply 'f', and
  -- project the result back with 'tprimalPart'. The same adapted map
  -- drives both the primal gather and the 'DeltaGatherX' node.
  txgather @shm @shn @shp sh (D uPrim uDelta) f =
    let gPrimal ix = fmap tprimalPart (f (fmap (tfromPrimal STKScalar) ix))
        resPrimal = txgather @_ @shm @shn @shp sh uPrim gPrimal
        resDelta = DeltaGatherX @shm @shn @shp knownShX knownShX knownShX
                                sh uDelta gPrimal
    in dD resPrimal resDelta
  -- The array is converted with the underlying target's 'txconcrete' and
  -- lifted with 'fromPrimalFTK'. The full shape is read off the input
  -- array itself.
  txconcrete a =
    fromPrimalFTK (FTKX (Nested.mshape a) FTKScalar) (txconcrete a)
  -- Only the primal part is used; the delta part of the argument is
  -- ignored. The floored result is lifted with 'fromPrimalFTK', and its
  -- shape is taken from the result itself.
  txfloor (D uPrim _) =
    fromPrimalFTK (FTKX (xshape floored) FTKScalar) floored
    where
      floored = txfloor uPrim
  -- Elementwise conversion from an integral scalar type. The input delta is
  -- discarded and the converted primal is wrapped via 'fromPrimalFTK'.
  txfromIntegral (D u _) =
    fromPrimalFTK (FTKX (xshape converted) FTKScalar) converted
    where
      converted = txfromIntegral u
  -- Cast between RealFrac scalar types. Unlike floor/fromIntegral, the
  -- delta is propagated, wrapped in 'DeltaCastX'.
  txcast (D u u') = dD primal dual
    where
      primal = txcast u
      dual = DeltaCastX u'
  -- Index-of-minimum. Per the method type, the result shape is 'Init' of the
  -- input shape. The input delta is discarded, so the result is wrapped as a
  -- primal-only value.
  txminIndex (D u _) =
    let idx = txminIndex u
        ftk = FTKX (xshape idx) FTKScalar
    in fromPrimalFTK ftk idx
  -- Index-of-maximum. Per the method type, the result shape is 'Init' of the
  -- input shape. The input delta is discarded, so the result is wrapped as a
  -- primal-only value.
  txmaxIndex (D u _) =
    let idx = txmaxIndex u
        ftk = FTKX (xshape idx) FTKScalar
    in fromPrimalFTK ftk idx
  -- The primal iota vector of statically known length n, lifted into ADVal
  -- with a rank-1 shape built from that known length.
  txiota =
    let shIota = Nested.SKnown SNat :$% ZSX
    in fromPrimalFTK (FTKX shIota FTKScalar) txiota
  -- Concatenate along the outermost axis: primals are appended directly and
  -- the deltas are combined with 'DeltaAppendX'.
  txappend (D front dFront) (D back dBack) =
    dD (txappend front back) (DeltaAppendX dFront dBack)
  -- Slice of the outermost axis: skip 'skip' elements, keep 'len', drop
  -- 'rest'. The same split is applied to the primal and, via 'DeltaSliceX',
  -- to the delta.
  txslice skip len rest (D prim dlt) =
    dD (txslice skip len rest prim) (DeltaSliceX skip len rest dlt)
  -- Reverse the outermost axis of both the primal and the delta.
  txreverse (D prim dlt) = dD (txreverse prim) (DeltaReverseX dlt)
  -- Permute a prefix of the axes. The explicit @target application pins the
  -- delta's carrier, which the other arguments do not fix positionally.
  txtranspose perm (D prim dlt) =
    let primal = txtranspose perm prim
        dual = DeltaTransposeX @_ @_ @_ @target perm dlt
    in dD primal dual
  -- Reshape to the given target shape; the same shape is recorded in the
  -- delta via 'DeltaReshapeX'.
  txreshape shOut (D prim dlt) =
    dD (txreshape shOut prim) (DeltaReshapeX shOut dlt)
  -- Build a tensor with outer size k by applying f to every index
  -- 0 .. k-1 and stacking the results with 'txfromVector'.
  -- When k is 0, f is never called, so nothing reveals the element shape.
  -- The empty array is then constructible only when sh is statically '[];
  -- any other sh is rejected with an error.
  txbuild1 @k @sh @r f =
    let indices = [0 .. valueOf @k - 1]
    in if not (null indices)
       then
         -- NOTE(review): relies on V.fromList consuming the comprehension;
         -- whether the intermediate list actually fuses is not verified.
         txfromVector $ V.fromList [f (fromInteger i) | i <- indices]
       else
         case testEquality (knownShX @sh) ZKX of
           Just Refl | Dict <- eltDictRep (knownSTK @r) ->
             let emptyArr = Nested.memptyArray @(RepConcrete r) ZSX
             in -- This branch runs only when the index list is empty,
                -- i.e. k == 0; GHC cannot see that, hence the coercion.
                gcastWith (unsafeCoerceRefl :: k :~: 0) $
                  tconcrete (tftkG knownSTK emptyArr) (Concrete emptyArr)
           Nothing -> error "xbuild1: shape ambiguity"

  -- Scalar ops
  -- Lift a plain scalar into ADVal as a primal-only value.
  tkconcrete a = fromPrimalFTK FTKScalar (tkconcrete a)
  -- Scalar floor into an integral type. The input delta is discarded.
  tkfloor (D u _) = fromPrimalFTK FTKScalar (tkfloor u)
  -- Scalar conversion from an integral type. The input delta is discarded.
  tkfromIntegral (D u _) = fromPrimalFTK FTKScalar (tkfromIntegral u)
  -- Scalar cast between RealFrac types; the delta is kept via 'DeltaCastK'.
  tkcast (D prim dlt) = dD (tkcast prim) (DeltaCastK dlt)

  -- General operations that require neither LetTensor nor ShareTensor
  -- Full shape of an ADVal, read from its delta component. The singleton
  -- argument is not needed here.
  tftk _ (D _ delta) = ftkDelta delta
  -- Lift a concrete value of the given full shape into ADVal as a
  -- primal-only value. The KnownSTK dictionary is recovered from the shape
  -- before delegating to the underlying target's 'tconcrete'.
  tconcrete ftk t = case lemKnownSTK (ftkToSTK ftk) of
    Dict -> fromPrimalFTK ftk (tconcrete ftk t)
  -- Pair two ADVals componentwise: primals with 'tpair', deltas with
  -- 'DeltaPair'. Built with 'dDnotShared' rather than 'dD'.
  tpair (D p1 d1) (D p2 d2) = dDnotShared (tpair p1 p2) (DeltaPair d1 d2)
  -- First component of a product: project the primal, and take the first
  -- half of the delta split by 'unDeltaPairUnshared'.
  tproject1 (D prim dlt) =
    let (dFst, _) = unDeltaPairUnshared dlt
    in dDnotShared (tproject1 prim) dFst
  -- Second component of a product: project the primal, and take the second
  -- half of the delta split by 'unDeltaPairUnshared'.
  tproject2 (D prim dlt) =
    let (_, dSnd) = unDeltaPairUnshared dlt
    in dDnotShared (tproject2 prim) dSnd
  -- Replicate a shaped tensor k times along a new outer axis. The delta
  -- records the per-element singleton (shape sh, element kind x) so that
  -- 'DeltaReplicate' can rebuild the result kind.
  tsreplicate snat sh (D prim dlt) =
    dD (tsreplicate snat sh prim) (DeltaReplicate snat elemStk dlt)
    where
      elemStk = STKS sh knownSTK
  -- Permute a prefix of the axes of a shaped dual number according to
  -- @perm@. The primal part is transposed by the underlying target's
  -- 'tstranspose'. The delta part records the operation as
  -- 'DeltaTransposeS'. The fourth type application pins that delta
  -- term's target to this instance's @target@.
  tstranspose perm (D u u') =
    dD (tstranspose perm u)
       (DeltaTransposeS @_ @_ @_ @target perm u')
  -- Reshape a shaped dual number to the shape @sh@. The element counts
  -- must agree: the method's constraint requires @Product sh ~ Product sh2@.
  -- The primal part is reshaped by the underlying target's 'tsreshape'.
  -- The delta part records the operation as 'DeltaReshapeS'.
  tsreshape sh (D u u') =
    dD (tsreshape sh u) (DeltaReshapeS sh u')
  tmapAccumRDer :: forall (accy :: TK) (by :: TK) (ey :: TK) (k :: Nat).
Proxy @Target (ADVal target)
-> SNat k
-> FullShapeTK accy
-> FullShapeTK by
-> FullShapeTK ey
-> HFunOf (ADVal target) (TKProduct accy ey) (TKProduct accy by)
-> HFunOf
     (ADVal target)
     (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy by))
-> HFunOf
     (ADVal target)
     (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy ey))
-> ADVal target accy
-> ADVal target (BuildTensorKind k ey)
-> ADVal target (TKProduct accy (BuildTensorKind k by))
tmapAccumRDer @accy @by @ey Proxy @Target (ADVal target)
_ !SNat k
k FullShapeTK accy
accftk FullShapeTK by
bftk FullShapeTK ey
eftk HFunOf (ADVal target) (TKProduct accy ey) (TKProduct accy by)
f HFunOf
  (ADVal target)
  (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy by))
df HFunOf
  (ADVal target)
  (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy ey))
rf ADVal target accy
acc0D ADVal target (BuildTensorKind k ey)
esD
   | Dict @TK KnownSTK (BuildTensorKind k accy)
Dict <- SNat k
-> SingletonTK accy -> Dict @TK KnownSTK (BuildTensorKind k accy)
forall (k :: Nat) (y :: TK).
SNat k -> SingletonTK y -> Dict @TK KnownSTK (BuildTensorKind k y)
lemKnownSTKOfBuild SNat k
k (FullShapeTK accy -> SingletonTK accy
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK accy
accftk)
   , Dict @TK KnownSTK (BuildTensorKind k ey)
Dict <- SNat k
-> SingletonTK ey -> Dict @TK KnownSTK (BuildTensorKind k ey)
forall (k :: Nat) (y :: TK).
SNat k -> SingletonTK y -> Dict @TK KnownSTK (BuildTensorKind k y)
lemKnownSTKOfBuild SNat k
k (FullShapeTK ey -> SingletonTK ey
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK ey
eftk) =
    let !(D target accy
acc0 Delta target accy
acc0') = ADVal target accy
acc0D in
    let !(D target (BuildTensorKind k ey)
esNotShared Delta target (BuildTensorKind k ey)
es') = ADVal target (BuildTensorKind k ey)
esD in
    let !es :: target (BuildTensorKind k ey)
es = target (BuildTensorKind k ey) -> target (BuildTensorKind k ey)
forall (y :: TK). target y -> target y
forall (target :: Target) (y :: TK).
ShareTensor target =>
target y -> target y
tshare target (BuildTensorKind k ey)
esNotShared
        codomainShs :: FullShapeTK (TKProduct accy by)
codomainShs = FullShapeTK accy
-> FullShapeTK by -> FullShapeTK (TKProduct accy by)
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK accy
accftk FullShapeTK by
bftk
        g :: forall f. ADReady f
          => f (TKProduct accy ey)
          -> f (TKProduct accy (TKProduct accy by))
        g :: forall (f :: Target).
ADReady f =>
f (TKProduct accy ey) -> f (TKProduct accy (TKProduct accy by))
g !f (TKProduct accy ey)
acc_e =
          f (TKProduct accy ey)
-> (f (TKProduct accy ey)
    -> f (TKProduct accy (TKProduct accy by)))
-> f (TKProduct accy (TKProduct accy by))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet f (TKProduct accy ey)
acc_e ((f (TKProduct accy ey) -> f (TKProduct accy (TKProduct accy by)))
 -> f (TKProduct accy (TKProduct accy by)))
-> (f (TKProduct accy ey)
    -> f (TKProduct accy (TKProduct accy by)))
-> f (TKProduct accy (TKProduct accy by))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct accy ey)
acc_e1 ->
          f (TKProduct accy by)
-> (f (TKProduct accy by)
    -> f (TKProduct accy (TKProduct accy by)))
-> f (TKProduct accy (TKProduct accy by))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet (HFun (TKProduct accy ey) (TKProduct accy by)
-> forall (f :: Target).
   ADReady f =>
   f (TKProduct accy ey) -> f (TKProduct accy by)
forall (x :: TK) (z :: TK).
HFun x z -> forall (f :: Target). ADReady f => f x -> f z
unHFun HFunOf (ADVal target) (TKProduct accy ey) (TKProduct accy by)
HFun (TKProduct accy ey) (TKProduct accy by)
f f (TKProduct accy ey)
acc_e) ((f (TKProduct accy by) -> f (TKProduct accy (TKProduct accy by)))
 -> f (TKProduct accy (TKProduct accy by)))
-> (f (TKProduct accy by)
    -> f (TKProduct accy (TKProduct accy by)))
-> f (TKProduct accy (TKProduct accy by))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct accy by)
accRes_bRes ->
            f accy
-> f (TKProduct accy by) -> f (TKProduct accy (TKProduct accy by))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (f (TKProduct accy by) -> f accy
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct accy by)
accRes_bRes)
                  (f accy -> f by -> f (TKProduct accy by)
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (f (TKProduct accy ey) -> f accy
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct accy ey)
acc_e1) (f (TKProduct accy by) -> f by
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct accy by)
accRes_bRes))
        dg :: forall f. ADReady f
           => f (TKProduct (ADTensorKind (TKProduct accy ey))
                           (TKProduct accy ey))
           -> f (ADTensorKind (TKProduct accy (TKProduct accy by)))
        dg :: forall (f :: Target).
ADReady f =>
f (TKProduct
     (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
-> f (ADTensorKind (TKProduct accy (TKProduct accy by)))
dg !f (TKProduct
     (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
dacc_de_acc_e =
          f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
-> (f (TKProduct
         (TKProduct (ADTensorKind accy) (ADTensorKind ey))
         (TKProduct accy ey))
    -> f (ADTensorKind (TKProduct accy (TKProduct accy by))))
-> f (ADTensorKind (TKProduct accy (TKProduct accy by)))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet f (TKProduct
     (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
dacc_de_acc_e ((f (TKProduct
       (TKProduct (ADTensorKind accy) (ADTensorKind ey))
       (TKProduct accy ey))
  -> f (ADTensorKind (TKProduct accy (TKProduct accy by))))
 -> f (ADTensorKind (TKProduct accy (TKProduct accy by))))
-> (f (TKProduct
         (TKProduct (ADTensorKind accy) (ADTensorKind ey))
         (TKProduct accy ey))
    -> f (ADTensorKind (TKProduct accy (TKProduct accy by))))
-> f (ADTensorKind (TKProduct accy (TKProduct accy by)))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
dacc_de_acc_e1 ->
            let (!f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
dacc_de, !f (TKProduct accy ey)
_acc_e) =
                  (f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
dacc_de_acc_e1, f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
-> f (TKProduct accy ey)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
dacc_de_acc_e1)
                !dacc1 :: f (ADTensorKind accy)
dacc1 = f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> f (ADTensorKind accy)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
dacc_de
            in f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
    -> f (ADTensorKind (TKProduct accy (TKProduct accy by))))
-> f (ADTensorKind (TKProduct accy (TKProduct accy by)))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet (HFun
  (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
  (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> forall (f :: Target).
   ADReady f =>
   f (TKProduct
        (TKProduct (ADTensorKind accy) (ADTensorKind ey))
        (TKProduct accy ey))
   -> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
forall (x :: TK) (z :: TK).
HFun x z -> forall (f :: Target). ADReady f => f x -> f z
unHFun HFunOf
  (ADVal target)
  (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy by))
HFun
  (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
  (TKProduct (ADTensorKind accy) (ADTensorKind by))
df f (TKProduct
     (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
dacc_de_acc_e) ((f (TKProduct (ADTensorKind accy) (ADTensorKind by))
  -> f (ADTensorKind (TKProduct accy (TKProduct accy by))))
 -> f (ADTensorKind (TKProduct accy (TKProduct accy by))))
-> (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
    -> f (ADTensorKind (TKProduct accy (TKProduct accy by))))
-> f (ADTensorKind (TKProduct accy (TKProduct accy by)))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct (ADTensorKind accy) (ADTensorKind by))
accRes_bRes ->
                 f (ADTensorKind accy)
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> f (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> f (ADTensorKind accy)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct (ADTensorKind accy) (ADTensorKind by))
accRes_bRes)
                       (f (ADTensorKind accy)
-> f (ADTensorKind by)
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair f (ADTensorKind accy)
dacc1 (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> f (ADTensorKind by)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct (ADTensorKind accy) (ADTensorKind by))
accRes_bRes))
        rg :: forall f. ADReady f
           => f (TKProduct (ADTensorKind (TKProduct accy
                                         (TKProduct accy by)))
                           (TKProduct accy ey))
           -> f (ADTensorKind (TKProduct accy ey))
        rg :: forall (f :: Target).
ADReady f =>
f (TKProduct
     (ADTensorKind (TKProduct accy (TKProduct accy by)))
     (TKProduct accy ey))
-> f (ADTensorKind (TKProduct accy ey))
rg !f (TKProduct
     (ADTensorKind (TKProduct accy (TKProduct accy by)))
     (TKProduct accy ey))
args =
          f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
-> (f (TKProduct
         (TKProduct
            (ADTensorKind accy)
            (TKProduct (ADTensorKind accy) (ADTensorKind by)))
         (TKProduct accy ey))
    -> f (ADTensorKind (TKProduct accy ey)))
-> f (ADTensorKind (TKProduct accy ey))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet f (TKProduct
     (ADTensorKind (TKProduct accy (TKProduct accy by)))
     (TKProduct accy ey))
f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
args ((f (TKProduct
       (TKProduct
          (ADTensorKind accy)
          (TKProduct (ADTensorKind accy) (ADTensorKind by)))
       (TKProduct accy ey))
  -> f (ADTensorKind (TKProduct accy ey)))
 -> f (ADTensorKind (TKProduct accy ey)))
-> (f (TKProduct
         (TKProduct
            (ADTensorKind accy)
            (TKProduct (ADTensorKind accy) (ADTensorKind by)))
         (TKProduct accy ey))
    -> f (ADTensorKind (TKProduct accy ey)))
-> f (ADTensorKind (TKProduct accy ey))
forall a b. (a -> b) -> a -> b
$ \ f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
args1 ->
            let (!f (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
dx_db, !f (TKProduct accy ey)
acc_e) = (f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
-> f (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
args1, f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
-> f (TKProduct accy ey)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
args1)
            in f (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> (f (TKProduct
         (ADTensorKind accy)
         (TKProduct (ADTensorKind accy) (ADTensorKind by)))
    -> f (ADTensorKind (TKProduct accy ey)))
-> f (ADTensorKind (TKProduct accy ey))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet f (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
dx_db ((f (TKProduct
       (ADTensorKind accy)
       (TKProduct (ADTensorKind accy) (ADTensorKind by)))
  -> f (ADTensorKind (TKProduct accy ey)))
 -> f (ADTensorKind (TKProduct accy ey)))
-> (f (TKProduct
         (ADTensorKind accy)
         (TKProduct (ADTensorKind accy) (ADTensorKind by)))
    -> f (ADTensorKind (TKProduct accy ey)))
-> f (ADTensorKind (TKProduct accy ey))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
dx_db1 ->
              let (!f (ADTensorKind accy)
dx, !f (TKProduct (ADTensorKind accy) (ADTensorKind by))
db) = (f (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> f (ADTensorKind accy)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
dx_db1, f (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
dx_db1)
              in f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
    -> f (ADTensorKind (TKProduct accy ey)))
-> f (ADTensorKind (TKProduct accy ey))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet f (TKProduct (ADTensorKind accy) (ADTensorKind by))
db ((f (TKProduct (ADTensorKind accy) (ADTensorKind by))
  -> f (ADTensorKind (TKProduct accy ey)))
 -> f (ADTensorKind (TKProduct accy ey)))
-> (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
    -> f (ADTensorKind (TKProduct accy ey)))
-> f (ADTensorKind (TKProduct accy ey))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct (ADTensorKind accy) (ADTensorKind by))
db1 ->
                let dx_dbRes :: f (TKProduct (ADTensorKind accy) (ADTensorKind by))
dx_dbRes = f (ADTensorKind accy)
-> f (ADTensorKind by)
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair f (ADTensorKind accy)
dx (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> f (ADTensorKind by)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct (ADTensorKind accy) (ADTensorKind by))
db1)
                in f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> (f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
    -> f (ADTensorKind (TKProduct accy ey)))
-> f (ADTensorKind (TKProduct accy ey))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet (HFun
  (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind by))
     (TKProduct accy ey))
  (TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> forall (f :: Target).
   ADReady f =>
   f (TKProduct
        (TKProduct (ADTensorKind accy) (ADTensorKind by))
        (TKProduct accy ey))
   -> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (x :: TK) (z :: TK).
HFun x z -> forall (f :: Target). ADReady f => f x -> f z
unHFun HFunOf
  (ADVal target)
  (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy ey))
HFun
  (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind by))
     (TKProduct accy ey))
  (TKProduct (ADTensorKind accy) (ADTensorKind ey))
rf (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> f (TKProduct accy ey)
-> f (TKProduct
        (TKProduct (ADTensorKind accy) (ADTensorKind by))
        (TKProduct accy ey))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair f (TKProduct (ADTensorKind accy) (ADTensorKind by))
dx_dbRes f (TKProduct accy ey)
acc_e))
                   ((f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
  -> f (ADTensorKind (TKProduct accy ey)))
 -> f (ADTensorKind (TKProduct accy ey)))
-> (f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
    -> f (ADTensorKind (TKProduct accy ey)))
-> f (ADTensorKind (TKProduct accy ey))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
daccRes_deRes ->
                  let added :: f (ADTensorKind accy)
added = SingletonTK (ADTensorKind accy)
-> f (ADTensorKind accy)
-> f (ADTensorKind accy)
-> f (ADTensorKind accy)
forall (y :: TK). SingletonTK y -> f y -> f y -> f y
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> target y -> target y
taddTarget (SingletonTK accy -> SingletonTK (ADTensorKind accy)
forall (y :: TK). SingletonTK y -> SingletonTK (ADTensorKind y)
adSTK (SingletonTK accy -> SingletonTK (ADTensorKind accy))
-> SingletonTK accy -> SingletonTK (ADTensorKind accy)
forall a b. (a -> b) -> a -> b
$ FullShapeTK accy -> SingletonTK accy
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK accy
accftk)
                                         (f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> f (ADTensorKind accy)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
daccRes_deRes)
                                         (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> f (ADTensorKind accy)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct (ADTensorKind accy) (ADTensorKind by))
db1)
                  in f (ADTensorKind accy)
-> f (ADTensorKind ey)
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair f (ADTensorKind accy)
added (f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> f (ADTensorKind ey)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
daccRes_deRes)
        p :: target (TKProduct accy (BuildTensorKind k (TKProduct accy by)))
p = Proxy @Target target
-> SNat k
-> FullShapeTK accy
-> FullShapeTK (TKProduct accy by)
-> FullShapeTK ey
-> HFunOf
     target (TKProduct accy ey) (TKProduct accy (TKProduct accy by))
-> HFunOf
     target
     (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy (TKProduct accy by)))
-> HFunOf
     target
     (TKProduct
        (ADTensorKind (TKProduct accy (TKProduct accy by)))
        (TKProduct accy ey))
     (ADTensorKind (TKProduct accy ey))
-> target accy
-> target (BuildTensorKind k ey)
-> target (TKProduct accy (BuildTensorKind k (TKProduct accy by)))
forall (accy :: TK) (by :: TK) (ey :: TK) (k :: Nat).
Proxy @Target target
-> SNat k
-> FullShapeTK accy
-> FullShapeTK by
-> FullShapeTK ey
-> HFunOf target (TKProduct accy ey) (TKProduct accy by)
-> HFunOf
     target
     (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy by))
-> HFunOf
     target
     (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy ey))
-> target accy
-> target (BuildTensorKind k ey)
-> target (TKProduct accy (BuildTensorKind k by))
forall (target :: Target) (accy :: TK) (by :: TK) (ey :: TK)
       (k :: Nat).
BaseTensor target =>
Proxy @Target target
-> SNat k
-> FullShapeTK accy
-> FullShapeTK by
-> FullShapeTK ey
-> HFunOf target (TKProduct accy ey) (TKProduct accy by)
-> HFunOf
     target
     (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy by))
-> HFunOf
     target
     (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy ey))
-> target accy
-> target (BuildTensorKind k ey)
-> target (TKProduct accy (BuildTensorKind k by))
tmapAccumRDer (forall {k} (t :: k). Proxy @k t
forall (t :: Target). Proxy @Target t
Proxy @target)
                          SNat k
k FullShapeTK accy
accftk FullShapeTK (TKProduct accy by)
codomainShs FullShapeTK ey
eftk
                          (forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
FullShapeTK x -> HFun x z -> HFunOf target x z
tlambda @target (FullShapeTK accy
-> FullShapeTK ey -> FullShapeTK (TKProduct accy ey)
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK accy
accftk FullShapeTK ey
eftk)
                           (HFun (TKProduct accy ey) (TKProduct accy (TKProduct accy by))
 -> HFunOf
      target (TKProduct accy ey) (TKProduct accy (TKProduct accy by)))
-> HFun (TKProduct accy ey) (TKProduct accy (TKProduct accy by))
-> HFunOf
     target (TKProduct accy ey) (TKProduct accy (TKProduct accy by))
forall a b. (a -> b) -> a -> b
$ (forall (f :: Target).
 ADReady f =>
 f (TKProduct accy ey) -> f (TKProduct accy (TKProduct accy by)))
-> HFun (TKProduct accy ey) (TKProduct accy (TKProduct accy by))
forall (x :: TK) (z :: TK).
(forall (f :: Target). ADReady f => f x -> f z) -> HFun x z
HFun f (TKProduct accy ey) -> f (TKProduct accy (TKProduct accy by))
forall (f :: Target).
ADReady f =>
f (TKProduct accy ey) -> f (TKProduct accy (TKProduct accy by))
g)
                          (forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
FullShapeTK x -> HFun x z -> HFunOf target x z
tlambda @target
                             (FullShapeTK (TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> FullShapeTK (TKProduct accy ey)
-> FullShapeTK
     (TKProduct
        (TKProduct (ADTensorKind accy) (ADTensorKind ey))
        (TKProduct accy ey))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct (FullShapeTK (TKProduct accy ey)
-> FullShapeTK (ADTensorKind (TKProduct accy ey))
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK (FullShapeTK accy
-> FullShapeTK ey -> FullShapeTK (TKProduct accy ey)
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK accy
accftk FullShapeTK ey
eftk))
                                         (FullShapeTK accy
-> FullShapeTK ey -> FullShapeTK (TKProduct accy ey)
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK accy
accftk FullShapeTK ey
eftk))
                           (HFun
   (TKProduct
      (TKProduct (ADTensorKind accy) (ADTensorKind ey))
      (TKProduct accy ey))
   (TKProduct
      (ADTensorKind accy)
      (TKProduct (ADTensorKind accy) (ADTensorKind by)))
 -> HFunOf
      target
      (TKProduct
         (TKProduct (ADTensorKind accy) (ADTensorKind ey))
         (TKProduct accy ey))
      (TKProduct
         (ADTensorKind accy)
         (TKProduct (ADTensorKind accy) (ADTensorKind by))))
-> HFun
     (TKProduct
        (TKProduct (ADTensorKind accy) (ADTensorKind ey))
        (TKProduct accy ey))
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> HFunOf
     target
     (TKProduct
        (TKProduct (ADTensorKind accy) (ADTensorKind ey))
        (TKProduct accy ey))
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
forall a b. (a -> b) -> a -> b
$ (forall (f :: Target).
 ADReady f =>
 f (TKProduct
      (TKProduct (ADTensorKind accy) (ADTensorKind ey))
      (TKProduct accy ey))
 -> f (TKProduct
         (ADTensorKind accy)
         (TKProduct (ADTensorKind accy) (ADTensorKind by))))
-> HFun
     (TKProduct
        (TKProduct (ADTensorKind accy) (ADTensorKind ey))
        (TKProduct accy ey))
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
forall (x :: TK) (z :: TK).
(forall (f :: Target). ADReady f => f x -> f z) -> HFun x z
HFun f (TKProduct
     (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
-> f (ADTensorKind (TKProduct accy (TKProduct accy by)))
f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
-> f (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
forall (f :: Target).
ADReady f =>
f (TKProduct
     (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
-> f (ADTensorKind (TKProduct accy (TKProduct accy by)))
forall (f :: Target).
ADReady f =>
f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
-> f (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
dg)
                          (forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
FullShapeTK x -> HFun x z -> HFunOf target x z
tlambda @target
                             (FullShapeTK
  (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> FullShapeTK (TKProduct accy ey)
-> FullShapeTK
     (TKProduct
        (TKProduct
           (ADTensorKind accy)
           (TKProduct (ADTensorKind accy) (ADTensorKind by)))
        (TKProduct accy ey))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct (FullShapeTK (TKProduct accy (TKProduct accy by))
-> FullShapeTK (ADTensorKind (TKProduct accy (TKProduct accy by)))
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK (FullShapeTK accy
-> FullShapeTK (TKProduct accy by)
-> FullShapeTK (TKProduct accy (TKProduct accy by))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK accy
accftk FullShapeTK (TKProduct accy by)
codomainShs))
                                         (FullShapeTK accy
-> FullShapeTK ey -> FullShapeTK (TKProduct accy ey)
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK accy
accftk FullShapeTK ey
eftk))
                           (HFun
   (TKProduct
      (TKProduct
         (ADTensorKind accy)
         (TKProduct (ADTensorKind accy) (ADTensorKind by)))
      (TKProduct accy ey))
   (TKProduct (ADTensorKind accy) (ADTensorKind ey))
 -> HFunOf
      target
      (TKProduct
         (TKProduct
            (ADTensorKind accy)
            (TKProduct (ADTensorKind accy) (ADTensorKind by)))
         (TKProduct accy ey))
      (TKProduct (ADTensorKind accy) (ADTensorKind ey)))
-> HFun
     (TKProduct
        (TKProduct
           (ADTensorKind accy)
           (TKProduct (ADTensorKind accy) (ADTensorKind by)))
        (TKProduct accy ey))
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> HFunOf
     target
     (TKProduct
        (TKProduct
           (ADTensorKind accy)
           (TKProduct (ADTensorKind accy) (ADTensorKind by)))
        (TKProduct accy ey))
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall a b. (a -> b) -> a -> b
$ (forall (f :: Target).
 ADReady f =>
 f (TKProduct
      (TKProduct
         (ADTensorKind accy)
         (TKProduct (ADTensorKind accy) (ADTensorKind by)))
      (TKProduct accy ey))
 -> f (TKProduct (ADTensorKind accy) (ADTensorKind ey)))
-> HFun
     (TKProduct
        (TKProduct
           (ADTensorKind accy)
           (TKProduct (ADTensorKind accy) (ADTensorKind by)))
        (TKProduct accy ey))
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (x :: TK) (z :: TK).
(forall (f :: Target). ADReady f => f x -> f z) -> HFun x z
HFun f (TKProduct
     (ADTensorKind (TKProduct accy (TKProduct accy by)))
     (TKProduct accy ey))
-> f (ADTensorKind (TKProduct accy ey))
f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (f :: Target).
ADReady f =>
f (TKProduct
     (ADTensorKind (TKProduct accy (TKProduct accy by)))
     (TKProduct accy ey))
-> f (ADTensorKind (TKProduct accy ey))
forall (f :: Target).
ADReady f =>
f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
rg)
                          target accy
acc0 target (BuildTensorKind k ey)
es
        (target accy
accFin, target (TKProduct (BuildTensorKind k accy) (BuildTensorKind k by))
qbs) = target
  (TKProduct
     accy (TKProduct (BuildTensorKind k accy) (BuildTensorKind k by)))
-> (target accy,
    target (TKProduct (BuildTensorKind k accy) (BuildTensorKind k by)))
forall (x :: TK) (z :: TK).
target (TKProduct x z) -> (target x, target z)
forall (target :: Target) (x :: TK) (z :: TK).
ShareTensor target =>
target (TKProduct x z) -> (target x, target z)
tunpair target (TKProduct accy (BuildTensorKind k (TKProduct accy by)))
target
  (TKProduct
     accy (TKProduct (BuildTensorKind k accy) (BuildTensorKind k by)))
p
        (target (BuildTensorKind k accy)
q, target (BuildTensorKind k by)
bs) = target (TKProduct (BuildTensorKind k accy) (BuildTensorKind k by))
-> (target (BuildTensorKind k accy), target (BuildTensorKind k by))
forall (x :: TK) (z :: TK).
target (TKProduct x z) -> (target x, target z)
forall (target :: Target) (x :: TK) (z :: TK).
ShareTensor target =>
target (TKProduct x z) -> (target x, target z)
tunpair target (TKProduct (BuildTensorKind k accy) (BuildTensorKind k by))
qbs
        dual :: Delta target (TKProduct accy (BuildTensorKind k by))
dual = SNat k
-> FullShapeTK by
-> FullShapeTK ey
-> target (BuildTensorKind k accy)
-> target (BuildTensorKind k ey)
-> HFun
     (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy by))
-> HFun
     (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy ey))
-> Delta target accy
-> Delta target (BuildTensorKind k ey)
-> Delta target (TKProduct accy (BuildTensorKind k by))
forall (a :: Target) (k :: Nat) (accy :: TK) (by :: TK) (ey :: TK).
(Show (a (BuildTensorKind k accy)),
 Show (a (BuildTensorKind k ey))) =>
SNat k
-> FullShapeTK by
-> FullShapeTK ey
-> a (BuildTensorKind k accy)
-> a (BuildTensorKind k ey)
-> HFun
     (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy by))
-> HFun
     (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy ey))
-> Delta a accy
-> Delta a (BuildTensorKind k ey)
-> Delta a (TKProduct accy (BuildTensorKind k by))
DeltaMapAccumR SNat k
k FullShapeTK by
bftk FullShapeTK ey
eftk target (BuildTensorKind k accy)
q target (BuildTensorKind k ey)
es HFunOf
  (ADVal target)
  (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy by))
HFun
  (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy by))
df HFunOf
  (ADVal target)
  (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy ey))
HFun
  (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy ey))
rf Delta target accy
acc0' Delta target (BuildTensorKind k ey)
es'
    in target (TKProduct accy (BuildTensorKind k by))
-> Delta target (TKProduct accy (BuildTensorKind k by))
-> ADVal target (TKProduct accy (BuildTensorKind k by))
forall (f :: Target) (z :: TK). f z -> Delta f z -> ADVal f z
dD (target accy
-> target (BuildTensorKind k by)
-> target (TKProduct accy (BuildTensorKind k by))
forall (x :: TK) (z :: TK).
target x -> target z -> target (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair target accy
accFin target (BuildTensorKind k by)
bs) Delta target (TKProduct accy (BuildTensorKind k by))
dual
  tmapAccumLDer :: forall (accy :: TK) (by :: TK) (ey :: TK) (k :: Nat).
Proxy @Target (ADVal target)
-> SNat k
-> FullShapeTK accy
-> FullShapeTK by
-> FullShapeTK ey
-> HFunOf (ADVal target) (TKProduct accy ey) (TKProduct accy by)
-> HFunOf
     (ADVal target)
     (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy by))
-> HFunOf
     (ADVal target)
     (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy ey))
-> ADVal target accy
-> ADVal target (BuildTensorKind k ey)
-> ADVal target (TKProduct accy (BuildTensorKind k by))
tmapAccumLDer @accy @by @ey Proxy @Target (ADVal target)
_ !SNat k
k FullShapeTK accy
accftk FullShapeTK by
bftk FullShapeTK ey
eftk HFunOf (ADVal target) (TKProduct accy ey) (TKProduct accy by)
f HFunOf
  (ADVal target)
  (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy by))
df HFunOf
  (ADVal target)
  (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy ey))
rf ADVal target accy
acc0D ADVal target (BuildTensorKind k ey)
esD
   | Dict @TK KnownSTK (BuildTensorKind k accy)
Dict <- SNat k
-> SingletonTK accy -> Dict @TK KnownSTK (BuildTensorKind k accy)
forall (k :: Nat) (y :: TK).
SNat k -> SingletonTK y -> Dict @TK KnownSTK (BuildTensorKind k y)
lemKnownSTKOfBuild SNat k
k (FullShapeTK accy -> SingletonTK accy
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK accy
accftk)
   , Dict @TK KnownSTK (BuildTensorKind k ey)
Dict <- SNat k
-> SingletonTK ey -> Dict @TK KnownSTK (BuildTensorKind k ey)
forall (k :: Nat) (y :: TK).
SNat k -> SingletonTK y -> Dict @TK KnownSTK (BuildTensorKind k y)
lemKnownSTKOfBuild SNat k
k (FullShapeTK ey -> SingletonTK ey
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK ey
eftk) =
    let !(D target accy
acc0 Delta target accy
acc0') = ADVal target accy
acc0D in
    let !(D target (BuildTensorKind k ey)
esNotShared Delta target (BuildTensorKind k ey)
es') = ADVal target (BuildTensorKind k ey)
esD in
    let !es :: target (BuildTensorKind k ey)
es = target (BuildTensorKind k ey) -> target (BuildTensorKind k ey)
forall (y :: TK). target y -> target y
forall (target :: Target) (y :: TK).
ShareTensor target =>
target y -> target y
tshare target (BuildTensorKind k ey)
esNotShared
        codomainShs :: FullShapeTK (TKProduct accy by)
codomainShs = FullShapeTK accy
-> FullShapeTK by -> FullShapeTK (TKProduct accy by)
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK accy
accftk FullShapeTK by
bftk
        g :: forall f. ADReady f
          => f (TKProduct accy ey)
          -> f (TKProduct accy (TKProduct accy by))
        g :: forall (f :: Target).
ADReady f =>
f (TKProduct accy ey) -> f (TKProduct accy (TKProduct accy by))
g !f (TKProduct accy ey)
acc_e =
          f (TKProduct accy ey)
-> (f (TKProduct accy ey)
    -> f (TKProduct accy (TKProduct accy by)))
-> f (TKProduct accy (TKProduct accy by))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet f (TKProduct accy ey)
acc_e ((f (TKProduct accy ey) -> f (TKProduct accy (TKProduct accy by)))
 -> f (TKProduct accy (TKProduct accy by)))
-> (f (TKProduct accy ey)
    -> f (TKProduct accy (TKProduct accy by)))
-> f (TKProduct accy (TKProduct accy by))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct accy ey)
acc_e1 ->
          f (TKProduct accy by)
-> (f (TKProduct accy by)
    -> f (TKProduct accy (TKProduct accy by)))
-> f (TKProduct accy (TKProduct accy by))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet (HFun (TKProduct accy ey) (TKProduct accy by)
-> forall (f :: Target).
   ADReady f =>
   f (TKProduct accy ey) -> f (TKProduct accy by)
forall (x :: TK) (z :: TK).
HFun x z -> forall (f :: Target). ADReady f => f x -> f z
unHFun HFunOf (ADVal target) (TKProduct accy ey) (TKProduct accy by)
HFun (TKProduct accy ey) (TKProduct accy by)
f f (TKProduct accy ey)
acc_e) ((f (TKProduct accy by) -> f (TKProduct accy (TKProduct accy by)))
 -> f (TKProduct accy (TKProduct accy by)))
-> (f (TKProduct accy by)
    -> f (TKProduct accy (TKProduct accy by)))
-> f (TKProduct accy (TKProduct accy by))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct accy by)
accRes_bRes ->
            f accy
-> f (TKProduct accy by) -> f (TKProduct accy (TKProduct accy by))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (f (TKProduct accy by) -> f accy
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct accy by)
accRes_bRes)
                  (f accy -> f by -> f (TKProduct accy by)
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (f (TKProduct accy ey) -> f accy
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct accy ey)
acc_e1) (f (TKProduct accy by) -> f by
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct accy by)
accRes_bRes))
        dg :: forall f. ADReady f
           => f (TKProduct (ADTensorKind (TKProduct accy ey))
                           (TKProduct accy ey))
           -> f (ADTensorKind (TKProduct accy (TKProduct accy by)))
        dg :: forall (f :: Target).
ADReady f =>
f (TKProduct
     (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
-> f (ADTensorKind (TKProduct accy (TKProduct accy by)))
dg !f (TKProduct
     (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
dacc_de_acc_e =
          f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
-> (f (TKProduct
         (TKProduct (ADTensorKind accy) (ADTensorKind ey))
         (TKProduct accy ey))
    -> f (ADTensorKind (TKProduct accy (TKProduct accy by))))
-> f (ADTensorKind (TKProduct accy (TKProduct accy by)))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet f (TKProduct
     (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
dacc_de_acc_e ((f (TKProduct
       (TKProduct (ADTensorKind accy) (ADTensorKind ey))
       (TKProduct accy ey))
  -> f (ADTensorKind (TKProduct accy (TKProduct accy by))))
 -> f (ADTensorKind (TKProduct accy (TKProduct accy by))))
-> (f (TKProduct
         (TKProduct (ADTensorKind accy) (ADTensorKind ey))
         (TKProduct accy ey))
    -> f (ADTensorKind (TKProduct accy (TKProduct accy by))))
-> f (ADTensorKind (TKProduct accy (TKProduct accy by)))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
dacc_de_acc_e1 ->
            let (!f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
dacc_de, !f (TKProduct accy ey)
_acc_e) =
                  (f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
dacc_de_acc_e1, f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
-> f (TKProduct accy ey)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
dacc_de_acc_e1)
                !dacc1 :: f (ADTensorKind accy)
dacc1 = f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> f (ADTensorKind accy)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
dacc_de
            in f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
    -> f (ADTensorKind (TKProduct accy (TKProduct accy by))))
-> f (ADTensorKind (TKProduct accy (TKProduct accy by)))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet (HFun
  (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
  (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> forall (f :: Target).
   ADReady f =>
   f (TKProduct
        (TKProduct (ADTensorKind accy) (ADTensorKind ey))
        (TKProduct accy ey))
   -> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
forall (x :: TK) (z :: TK).
HFun x z -> forall (f :: Target). ADReady f => f x -> f z
unHFun HFunOf
  (ADVal target)
  (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy by))
HFun
  (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
  (TKProduct (ADTensorKind accy) (ADTensorKind by))
df f (TKProduct
     (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
dacc_de_acc_e) ((f (TKProduct (ADTensorKind accy) (ADTensorKind by))
  -> f (ADTensorKind (TKProduct accy (TKProduct accy by))))
 -> f (ADTensorKind (TKProduct accy (TKProduct accy by))))
-> (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
    -> f (ADTensorKind (TKProduct accy (TKProduct accy by))))
-> f (ADTensorKind (TKProduct accy (TKProduct accy by)))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct (ADTensorKind accy) (ADTensorKind by))
accRes_bRes ->
                 f (ADTensorKind accy)
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> f (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> f (ADTensorKind accy)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct (ADTensorKind accy) (ADTensorKind by))
accRes_bRes)
                       (f (ADTensorKind accy)
-> f (ADTensorKind by)
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair f (ADTensorKind accy)
dacc1 (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> f (ADTensorKind by)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct (ADTensorKind accy) (ADTensorKind by))
accRes_bRes))
        rg :: forall f. ADReady f
           => f (TKProduct (ADTensorKind (TKProduct accy
                                         (TKProduct accy by)))
                           (TKProduct accy ey))
           -> f (ADTensorKind (TKProduct accy ey))
        rg :: forall (f :: Target).
ADReady f =>
f (TKProduct
     (ADTensorKind (TKProduct accy (TKProduct accy by)))
     (TKProduct accy ey))
-> f (ADTensorKind (TKProduct accy ey))
rg !f (TKProduct
     (ADTensorKind (TKProduct accy (TKProduct accy by)))
     (TKProduct accy ey))
args =
          f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
-> (f (TKProduct
         (TKProduct
            (ADTensorKind accy)
            (TKProduct (ADTensorKind accy) (ADTensorKind by)))
         (TKProduct accy ey))
    -> f (ADTensorKind (TKProduct accy ey)))
-> f (ADTensorKind (TKProduct accy ey))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet f (TKProduct
     (ADTensorKind (TKProduct accy (TKProduct accy by)))
     (TKProduct accy ey))
f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
args ((f (TKProduct
       (TKProduct
          (ADTensorKind accy)
          (TKProduct (ADTensorKind accy) (ADTensorKind by)))
       (TKProduct accy ey))
  -> f (ADTensorKind (TKProduct accy ey)))
 -> f (ADTensorKind (TKProduct accy ey)))
-> (f (TKProduct
         (TKProduct
            (ADTensorKind accy)
            (TKProduct (ADTensorKind accy) (ADTensorKind by)))
         (TKProduct accy ey))
    -> f (ADTensorKind (TKProduct accy ey)))
-> f (ADTensorKind (TKProduct accy ey))
forall a b. (a -> b) -> a -> b
$ \ f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
args1 ->
            let (!f (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
dx_db, !f (TKProduct accy ey)
acc_e) = (f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
-> f (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
args1, f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
-> f (TKProduct accy ey)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
args1)
            in f (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> (f (TKProduct
         (ADTensorKind accy)
         (TKProduct (ADTensorKind accy) (ADTensorKind by)))
    -> f (ADTensorKind (TKProduct accy ey)))
-> f (ADTensorKind (TKProduct accy ey))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet f (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
dx_db ((f (TKProduct
       (ADTensorKind accy)
       (TKProduct (ADTensorKind accy) (ADTensorKind by)))
  -> f (ADTensorKind (TKProduct accy ey)))
 -> f (ADTensorKind (TKProduct accy ey)))
-> (f (TKProduct
         (ADTensorKind accy)
         (TKProduct (ADTensorKind accy) (ADTensorKind by)))
    -> f (ADTensorKind (TKProduct accy ey)))
-> f (ADTensorKind (TKProduct accy ey))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
dx_db1 ->
              let (!f (ADTensorKind accy)
dx, !f (TKProduct (ADTensorKind accy) (ADTensorKind by))
db) = (f (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> f (ADTensorKind accy)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
dx_db1, f (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
dx_db1)
              in f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
    -> f (ADTensorKind (TKProduct accy ey)))
-> f (ADTensorKind (TKProduct accy ey))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet f (TKProduct (ADTensorKind accy) (ADTensorKind by))
db ((f (TKProduct (ADTensorKind accy) (ADTensorKind by))
  -> f (ADTensorKind (TKProduct accy ey)))
 -> f (ADTensorKind (TKProduct accy ey)))
-> (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
    -> f (ADTensorKind (TKProduct accy ey)))
-> f (ADTensorKind (TKProduct accy ey))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct (ADTensorKind accy) (ADTensorKind by))
db1 ->
                let dx_dbRes :: f (TKProduct (ADTensorKind accy) (ADTensorKind by))
dx_dbRes = f (ADTensorKind accy)
-> f (ADTensorKind by)
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair f (ADTensorKind accy)
dx (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> f (ADTensorKind by)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct (ADTensorKind accy) (ADTensorKind by))
db1)
                in f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> (f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
    -> f (ADTensorKind (TKProduct accy ey)))
-> f (ADTensorKind (TKProduct accy ey))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet (HFun
  (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind by))
     (TKProduct accy ey))
  (TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> forall (f :: Target).
   ADReady f =>
   f (TKProduct
        (TKProduct (ADTensorKind accy) (ADTensorKind by))
        (TKProduct accy ey))
   -> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (x :: TK) (z :: TK).
HFun x z -> forall (f :: Target). ADReady f => f x -> f z
unHFun HFunOf
  (ADVal target)
  (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy ey))
HFun
  (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind by))
     (TKProduct accy ey))
  (TKProduct (ADTensorKind accy) (ADTensorKind ey))
rf (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> f (TKProduct accy ey)
-> f (TKProduct
        (TKProduct (ADTensorKind accy) (ADTensorKind by))
        (TKProduct accy ey))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair f (TKProduct (ADTensorKind accy) (ADTensorKind by))
dx_dbRes f (TKProduct accy ey)
acc_e))
                   ((f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
  -> f (ADTensorKind (TKProduct accy ey)))
 -> f (ADTensorKind (TKProduct accy ey)))
-> (f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
    -> f (ADTensorKind (TKProduct accy ey)))
-> f (ADTensorKind (TKProduct accy ey))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
daccRes_deRes ->
                  let added :: f (ADTensorKind accy)
added = SingletonTK (ADTensorKind accy)
-> f (ADTensorKind accy)
-> f (ADTensorKind accy)
-> f (ADTensorKind accy)
forall (y :: TK). SingletonTK y -> f y -> f y -> f y
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> target y -> target y
taddTarget (SingletonTK accy -> SingletonTK (ADTensorKind accy)
forall (y :: TK). SingletonTK y -> SingletonTK (ADTensorKind y)
adSTK (SingletonTK accy -> SingletonTK (ADTensorKind accy))
-> SingletonTK accy -> SingletonTK (ADTensorKind accy)
forall a b. (a -> b) -> a -> b
$ FullShapeTK accy -> SingletonTK accy
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK accy
accftk)
                                         (f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> f (ADTensorKind accy)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
daccRes_deRes)
                                         (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> f (ADTensorKind accy)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct (ADTensorKind accy) (ADTensorKind by))
db1)
                  in f (ADTensorKind accy)
-> f (ADTensorKind ey)
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair f (ADTensorKind accy)
added (f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> f (ADTensorKind ey)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
daccRes_deRes)
        p :: target (TKProduct accy (BuildTensorKind k (TKProduct accy by)))
p = Proxy @Target target
-> SNat k
-> FullShapeTK accy
-> FullShapeTK (TKProduct accy by)
-> FullShapeTK ey
-> HFunOf
     target (TKProduct accy ey) (TKProduct accy (TKProduct accy by))
-> HFunOf
     target
     (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy (TKProduct accy by)))
-> HFunOf
     target
     (TKProduct
        (ADTensorKind (TKProduct accy (TKProduct accy by)))
        (TKProduct accy ey))
     (ADTensorKind (TKProduct accy ey))
-> target accy
-> target (BuildTensorKind k ey)
-> target (TKProduct accy (BuildTensorKind k (TKProduct accy by)))
forall (accy :: TK) (by :: TK) (ey :: TK) (k :: Nat).
Proxy @Target target
-> SNat k
-> FullShapeTK accy
-> FullShapeTK by
-> FullShapeTK ey
-> HFunOf target (TKProduct accy ey) (TKProduct accy by)
-> HFunOf
     target
     (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy by))
-> HFunOf
     target
     (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy ey))
-> target accy
-> target (BuildTensorKind k ey)
-> target (TKProduct accy (BuildTensorKind k by))
forall (target :: Target) (accy :: TK) (by :: TK) (ey :: TK)
       (k :: Nat).
BaseTensor target =>
Proxy @Target target
-> SNat k
-> FullShapeTK accy
-> FullShapeTK by
-> FullShapeTK ey
-> HFunOf target (TKProduct accy ey) (TKProduct accy by)
-> HFunOf
     target
     (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy by))
-> HFunOf
     target
     (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy ey))
-> target accy
-> target (BuildTensorKind k ey)
-> target (TKProduct accy (BuildTensorKind k by))
tmapAccumLDer (forall {k} (t :: k). Proxy @k t
forall (t :: Target). Proxy @Target t
Proxy @target)
                          SNat k
k FullShapeTK accy
accftk FullShapeTK (TKProduct accy by)
codomainShs FullShapeTK ey
eftk
                          (forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
FullShapeTK x -> HFun x z -> HFunOf target x z
tlambda @target (FullShapeTK accy
-> FullShapeTK ey -> FullShapeTK (TKProduct accy ey)
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK accy
accftk FullShapeTK ey
eftk)
                           (HFun (TKProduct accy ey) (TKProduct accy (TKProduct accy by))
 -> HFunOf
      target (TKProduct accy ey) (TKProduct accy (TKProduct accy by)))
-> HFun (TKProduct accy ey) (TKProduct accy (TKProduct accy by))
-> HFunOf
     target (TKProduct accy ey) (TKProduct accy (TKProduct accy by))
forall a b. (a -> b) -> a -> b
$ (forall (f :: Target).
 ADReady f =>
 f (TKProduct accy ey) -> f (TKProduct accy (TKProduct accy by)))
-> HFun (TKProduct accy ey) (TKProduct accy (TKProduct accy by))
forall (x :: TK) (z :: TK).
(forall (f :: Target). ADReady f => f x -> f z) -> HFun x z
HFun f (TKProduct accy ey) -> f (TKProduct accy (TKProduct accy by))
forall (f :: Target).
ADReady f =>
f (TKProduct accy ey) -> f (TKProduct accy (TKProduct accy by))
g)
                          (forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
FullShapeTK x -> HFun x z -> HFunOf target x z
tlambda @target
                             (FullShapeTK (TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> FullShapeTK (TKProduct accy ey)
-> FullShapeTK
     (TKProduct
        (TKProduct (ADTensorKind accy) (ADTensorKind ey))
        (TKProduct accy ey))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct (FullShapeTK (TKProduct accy ey)
-> FullShapeTK (ADTensorKind (TKProduct accy ey))
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK (FullShapeTK accy
-> FullShapeTK ey -> FullShapeTK (TKProduct accy ey)
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK accy
accftk FullShapeTK ey
eftk))
                                         (FullShapeTK accy
-> FullShapeTK ey -> FullShapeTK (TKProduct accy ey)
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK accy
accftk FullShapeTK ey
eftk))
                           (HFun
   (TKProduct
      (TKProduct (ADTensorKind accy) (ADTensorKind ey))
      (TKProduct accy ey))
   (TKProduct
      (ADTensorKind accy)
      (TKProduct (ADTensorKind accy) (ADTensorKind by)))
 -> HFunOf
      target
      (TKProduct
         (TKProduct (ADTensorKind accy) (ADTensorKind ey))
         (TKProduct accy ey))
      (TKProduct
         (ADTensorKind accy)
         (TKProduct (ADTensorKind accy) (ADTensorKind by))))
-> HFun
     (TKProduct
        (TKProduct (ADTensorKind accy) (ADTensorKind ey))
        (TKProduct accy ey))
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> HFunOf
     target
     (TKProduct
        (TKProduct (ADTensorKind accy) (ADTensorKind ey))
        (TKProduct accy ey))
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
forall a b. (a -> b) -> a -> b
$ (forall (f :: Target).
 ADReady f =>
 f (TKProduct
      (TKProduct (ADTensorKind accy) (ADTensorKind ey))
      (TKProduct accy ey))
 -> f (TKProduct
         (ADTensorKind accy)
         (TKProduct (ADTensorKind accy) (ADTensorKind by))))
-> HFun
     (TKProduct
        (TKProduct (ADTensorKind accy) (ADTensorKind ey))
        (TKProduct accy ey))
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
forall (x :: TK) (z :: TK).
(forall (f :: Target). ADReady f => f x -> f z) -> HFun x z
HFun f (TKProduct
     (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
-> f (ADTensorKind (TKProduct accy (TKProduct accy by)))
f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
-> f (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
forall (f :: Target).
ADReady f =>
f (TKProduct
     (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
-> f (ADTensorKind (TKProduct accy (TKProduct accy by)))
forall (f :: Target).
ADReady f =>
f (TKProduct
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
     (TKProduct accy ey))
-> f (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
dg)
                          (forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
FullShapeTK x -> HFun x z -> HFunOf target x z
tlambda @target
                             (FullShapeTK
  (TKProduct
     (ADTensorKind accy)
     (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> FullShapeTK (TKProduct accy ey)
-> FullShapeTK
     (TKProduct
        (TKProduct
           (ADTensorKind accy)
           (TKProduct (ADTensorKind accy) (ADTensorKind by)))
        (TKProduct accy ey))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct (FullShapeTK (TKProduct accy (TKProduct accy by))
-> FullShapeTK (ADTensorKind (TKProduct accy (TKProduct accy by)))
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK (FullShapeTK accy
-> FullShapeTK (TKProduct accy by)
-> FullShapeTK (TKProduct accy (TKProduct accy by))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK accy
accftk FullShapeTK (TKProduct accy by)
codomainShs))
                                         (FullShapeTK accy
-> FullShapeTK ey -> FullShapeTK (TKProduct accy ey)
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK accy
accftk FullShapeTK ey
eftk))
                           (HFun
   (TKProduct
      (TKProduct
         (ADTensorKind accy)
         (TKProduct (ADTensorKind accy) (ADTensorKind by)))
      (TKProduct accy ey))
   (TKProduct (ADTensorKind accy) (ADTensorKind ey))
 -> HFunOf
      target
      (TKProduct
         (TKProduct
            (ADTensorKind accy)
            (TKProduct (ADTensorKind accy) (ADTensorKind by)))
         (TKProduct accy ey))
      (TKProduct (ADTensorKind accy) (ADTensorKind ey)))
-> HFun
     (TKProduct
        (TKProduct
           (ADTensorKind accy)
           (TKProduct (ADTensorKind accy) (ADTensorKind by)))
        (TKProduct accy ey))
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> HFunOf
     target
     (TKProduct
        (TKProduct
           (ADTensorKind accy)
           (TKProduct (ADTensorKind accy) (ADTensorKind by)))
        (TKProduct accy ey))
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall a b. (a -> b) -> a -> b
$ (forall (f :: Target).
 ADReady f =>
 f (TKProduct
      (TKProduct
         (ADTensorKind accy)
         (TKProduct (ADTensorKind accy) (ADTensorKind by)))
      (TKProduct accy ey))
 -> f (TKProduct (ADTensorKind accy) (ADTensorKind ey)))
-> HFun
     (TKProduct
        (TKProduct
           (ADTensorKind accy)
           (TKProduct (ADTensorKind accy) (ADTensorKind by)))
        (TKProduct accy ey))
     (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (x :: TK) (z :: TK).
(forall (f :: Target). ADReady f => f x -> f z) -> HFun x z
HFun f (TKProduct
     (ADTensorKind (TKProduct accy (TKProduct accy by)))
     (TKProduct accy ey))
-> f (ADTensorKind (TKProduct accy ey))
f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (f :: Target).
ADReady f =>
f (TKProduct
     (ADTensorKind (TKProduct accy (TKProduct accy by)))
     (TKProduct accy ey))
-> f (ADTensorKind (TKProduct accy ey))
forall (f :: Target).
ADReady f =>
f (TKProduct
     (TKProduct
        (ADTensorKind accy)
        (TKProduct (ADTensorKind accy) (ADTensorKind by)))
     (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
rg)
                          target accy
acc0 target (BuildTensorKind k ey)
es
        (target accy
accFin, target (TKProduct (BuildTensorKind k accy) (BuildTensorKind k by))
qbs) = target
  (TKProduct
     accy (TKProduct (BuildTensorKind k accy) (BuildTensorKind k by)))
-> (target accy,
    target (TKProduct (BuildTensorKind k accy) (BuildTensorKind k by)))
forall (x :: TK) (z :: TK).
target (TKProduct x z) -> (target x, target z)
forall (target :: Target) (x :: TK) (z :: TK).
ShareTensor target =>
target (TKProduct x z) -> (target x, target z)
tunpair target (TKProduct accy (BuildTensorKind k (TKProduct accy by)))
target
  (TKProduct
     accy (TKProduct (BuildTensorKind k accy) (BuildTensorKind k by)))
p
        (target (BuildTensorKind k accy)
q, target (BuildTensorKind k by)
bs) = target (TKProduct (BuildTensorKind k accy) (BuildTensorKind k by))
-> (target (BuildTensorKind k accy), target (BuildTensorKind k by))
forall (x :: TK) (z :: TK).
target (TKProduct x z) -> (target x, target z)
forall (target :: Target) (x :: TK) (z :: TK).
ShareTensor target =>
target (TKProduct x z) -> (target x, target z)
tunpair target (TKProduct (BuildTensorKind k accy) (BuildTensorKind k by))
qbs
        dual :: Delta target (TKProduct accy (BuildTensorKind k by))
dual = SNat k
-> FullShapeTK by
-> FullShapeTK ey
-> target (BuildTensorKind k accy)
-> target (BuildTensorKind k ey)
-> HFun
     (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy by))
-> HFun
     (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy ey))
-> Delta target accy
-> Delta target (BuildTensorKind k ey)
-> Delta target (TKProduct accy (BuildTensorKind k by))
forall (a :: Target) (k :: Nat) (accy :: TK) (by :: TK) (ey :: TK).
(Show (a (BuildTensorKind k accy)),
 Show (a (BuildTensorKind k ey))) =>
SNat k
-> FullShapeTK by
-> FullShapeTK ey
-> a (BuildTensorKind k accy)
-> a (BuildTensorKind k ey)
-> HFun
     (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy by))
-> HFun
     (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
     (ADTensorKind (TKProduct accy ey))
-> Delta a accy
-> Delta a (BuildTensorKind k ey)
-> Delta a (TKProduct accy (BuildTensorKind k by))
DeltaMapAccumL SNat k
k FullShapeTK by
bftk FullShapeTK ey
eftk target (BuildTensorKind k accy)
q target (BuildTensorKind k ey)
es HFunOf
  (ADVal target)
  (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy by))
HFun
  (TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy by))
df HFunOf
  (ADVal target)
  (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy ey))
HFun
  (TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
  (ADTensorKind (TKProduct accy ey))
rf Delta target accy
acc0' Delta target (BuildTensorKind k ey)
es'
    in target (TKProduct accy (BuildTensorKind k by))
-> Delta target (TKProduct accy (BuildTensorKind k by))
-> ADVal target (TKProduct accy (BuildTensorKind k by))
forall (f :: Target) (z :: TK). f z -> Delta f z -> ADVal f z
dD (target accy
-> target (BuildTensorKind k by)
-> target (TKProduct accy (BuildTensorKind k by))
forall (x :: TK) (z :: TK).
target x -> target z -> target (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair target accy
accFin target (BuildTensorKind k by)
bs) Delta target (TKProduct accy (BuildTensorKind k by))
dual
  tApply :: forall (x :: TK) (z :: TK).
HFunOf (ADVal target) x z -> ADVal target x -> ADVal target z
tApply (HFun forall (f :: Target). ADReady f => f x -> f z
f) = ADVal target x -> ADVal target z
forall (f :: Target). ADReady f => f x -> f z
f
  tlambda :: forall (x :: TK) (z :: TK).
FullShapeTK x -> HFun x z -> HFunOf (ADVal target) x z
tlambda FullShapeTK x
_ = HFun x z -> HFunOf (ADVal target) x z
HFun x z -> HFun x z
forall a. a -> a
id
  -- Bangs are for the proper order of sharing stamps.
  tcond :: forall (y :: TK).
Boolean (BoolOf (ADVal target)) =>
SingletonTK y
-> BoolOf (ADVal target)
-> ADVal target y
-> ADVal target y
-> ADVal target y
tcond !SingletonTK y
stk !BoolOf (ADVal target)
b !ADVal target y
u !ADVal target y
v =
    let uv :: ADVal target (BuildTensorKind 2 y)
uv = SNat 2
-> SingletonTK y
-> Vector (ADVal target y)
-> ADVal target (BuildTensorKind 2 y)
forall (y :: TK) (k :: Nat).
SNat k
-> SingletonTK y
-> Vector (ADVal target y)
-> ADVal target (BuildTensorKind k y)
forall (target :: Target) (y :: TK) (k :: Nat).
BaseTensor target =>
SNat k
-> SingletonTK y
-> Vector (target y)
-> target (BuildTensorKind k y)
tfromVector (forall (n :: Nat). KnownNat n => SNat n
SNat @2) SingletonTK y
stk ([ADVal target y] -> Vector (ADVal target y)
forall (v :: Type -> Type) a. Vector v a => [a] -> v a
V.fromList [ADVal target y
u, ADVal target y
v])
    in SNat 2
-> SingletonTK y
-> ADVal target (BuildTensorKind 2 y)
-> IntOf (ADVal target)
-> ADVal target y
forall (z :: TK) (k :: Nat).
ConvertTensor (ADVal target) =>
SNat k
-> SingletonTK z
-> ADVal target (BuildTensorKind k z)
-> IntOf (ADVal target)
-> ADVal target z
forall (target :: Target) (z :: TK) (k :: Nat).
(BaseTensor target, ConvertTensor target) =>
SNat k
-> SingletonTK z
-> target (BuildTensorKind k z)
-> IntOf target
-> target z
tindexBuild (forall (n :: Nat). KnownNat n => SNat n
SNat @2) SingletonTK y
stk ADVal target (BuildTensorKind 2 y)
uv (SingletonTK (TKScalar Int64)
-> BoolOf target
-> target (TKScalar Int64)
-> target (TKScalar Int64)
-> target (TKScalar Int64)
forall (y :: TK).
Boolean (BoolOf target) =>
SingletonTK y -> BoolOf target -> target y -> target y -> target y
forall (target :: Target) (y :: TK).
(BaseTensor target, Boolean (BoolOf target)) =>
SingletonTK y -> BoolOf target -> target y -> target y -> target y
tcond SingletonTK (TKScalar Int64)
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK BoolOf target
BoolOf (ADVal target)
b target (TKScalar Int64)
0 target (TKScalar Int64)
1)
  tprimalPart :: forall (y :: TK). ADVal target y -> PrimalOf (ADVal target) y
tprimalPart (D target y
u Delta target y
_) = target y
PrimalOf (ADVal target) y
u
  tdualPart :: forall (y :: TK).
SingletonTK y -> ADVal target y -> DualOf (ADVal target) y
tdualPart SingletonTK y
_stk (D target y
_ Delta target y
u') = DualOf (ADVal target) y
Delta target y
u'
  tfromPrimal :: forall (y :: TK).
SingletonTK y -> PrimalOf (ADVal target) y -> ADVal target y
tfromPrimal SingletonTK y
stk PrimalOf (ADVal target) y
t = FullShapeTK y -> target y -> ADVal target y
forall (z :: TK) (f :: Target). FullShapeTK z -> f z -> ADVal f z
fromPrimalFTK (SingletonTK y -> target y -> FullShapeTK y
forall (y :: TK). SingletonTK y -> target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk SingletonTK y
stk target y
PrimalOf (ADVal target) y
t) target y
PrimalOf (ADVal target) y
t
  tfromDual :: forall (y :: TK). DualOf (ADVal target) y -> ADVal target y
tfromDual DualOf (ADVal target) y
t = target y -> Delta target y -> ADVal target y
forall (f :: Target) (z :: TK). f z -> Delta f z -> ADVal f z
dDnotShared (FullShapeTK y -> target y
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget (Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta DualOf (ADVal target) y
Delta target y
t)) DualOf (ADVal target) y
Delta target y
t
  tScale :: forall (y :: TK).
(Num (ADVal target y), Num (PrimalOf (ADVal target) y)) =>
SingletonTK y
-> PrimalOf (ADVal target) y
-> DualOf (ADVal target) y
-> DualOf (ADVal target) y
tScale SingletonTK y
_stk = target y -> Delta target y -> Delta target y
PrimalOf (ADVal target) y
-> DualOf (ADVal target) y -> DualOf (ADVal target) y
forall (f :: Target) (z :: TK).
Num (f z) =>
f z -> Delta f z -> Delta f z
dScale
  tgrad :: forall (x :: TK) r.
FullShapeTK x
-> HFun x (TKScalar r) -> HFunOf (ADVal target) x (ADTensorKind x)
tgrad @x FullShapeTK x
xftk HFun x (TKScalar r)
h =
    let rf :: forall f. ADReady f
           => f x
           -> f (ADTensorKind x)
        -- This computes the derivative of g again for each new a.
        rf :: forall (f :: Target). ADReady f => f x -> f (ADTensorKind x)
rf !f x
a = f x -> (f x -> f (ADTensorKind x)) -> f (ADTensorKind x)
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet f x
a ((f x -> f (ADTensorKind x)) -> f (ADTensorKind x))
-> (f x -> f (ADTensorKind x)) -> f (ADTensorKind x)
forall a b. (a -> b) -> a -> b
$ \ !f x
aShared ->
          ShareOf f (ADTensorKind x) -> f (ADTensorKind x)
forall (y :: TK). ShareOf f y -> f y
forall (target :: Target) (y :: TK).
LetTensor target =>
ShareOf target y -> target y
tunshare (ShareOf f (ADTensorKind x) -> f (ADTensorKind x))
-> ShareOf f (ADTensorKind x) -> f (ADTensorKind x)
forall a b. (a -> b) -> a -> b
$ (ShareOf f (ADTensorKind x), ShareOf f (TKScalar r))
-> ShareOf f (ADTensorKind x)
forall a b. (a, b) -> a
fst ((ShareOf f (ADTensorKind x), ShareOf f (TKScalar r))
 -> ShareOf f (ADTensorKind x))
-> (ShareOf f (ADTensorKind x), ShareOf f (TKScalar r))
-> ShareOf f (ADTensorKind x)
forall a b. (a -> b) -> a -> b
$ Maybe (ShareOf f (ADTensorKind (TKScalar r)))
-> (ADVal (ShareOf f) x -> ADVal (ShareOf f) (TKScalar r))
-> FullShapeTK x
-> ShareOf f x
-> (ShareOf f (ADTensorKind x), ShareOf f (TKScalar r))
forall (x :: TK) (z :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
Maybe (target (ADTensorKind z))
-> (ADVal target x -> ADVal target z)
-> FullShapeTK x
-> target x
-> (target (ADTensorKind x), target z)
crevOnParams
                             Maybe (ShareOf f (ADTensorKind (TKScalar r)))
Maybe (ShareOf f (TKScalar (ADTensorScalar r)))
forall a. Maybe a
Nothing
                             (HFun x (TKScalar r)
-> forall (f :: Target). ADReady f => f x -> f (TKScalar r)
forall (x :: TK) (z :: TK).
HFun x z -> forall (f :: Target). ADReady f => f x -> f z
unHFun HFun x (TKScalar r)
h @(ADVal (ShareOf f)))
                             FullShapeTK x
xftk
                             (f x -> ShareOf f x
forall (y :: TK). f y -> ShareOf f y
forall (target :: Target) (y :: TK).
LetTensor target =>
target y -> ShareOf target y
toShare f x
aShared)
    in (forall (f :: Target). ADReady f => f x -> f (ADTensorKind x))
-> HFun x (ADTensorKind x)
forall (x :: TK) (z :: TK).
(forall (f :: Target). ADReady f => f x -> f z) -> HFun x z
HFun f x -> f (ADTensorKind x)
forall (f :: Target). ADReady f => f x -> f (ADTensorKind x)
rf
  tvjp :: forall (x :: TK) (z :: TK).
FullShapeTK x
-> HFun x z
-> HFunOf
     (ADVal target) (TKProduct (ADTensorKind z) x) (ADTensorKind x)
tvjp @x @z FullShapeTK x
xftk HFun x z
h =
    let rf :: forall f. ADReady f
           => f (TKProduct (ADTensorKind z) x)
           -> f (ADTensorKind x)
        -- This computes the derivative of g again for each new db and a.
        rf :: forall (f :: Target).
ADReady f =>
f (TKProduct (ADTensorKind z) x) -> f (ADTensorKind x)
rf !f (TKProduct (ADTensorKind z) x)
db_a = f (TKProduct (ADTensorKind z) x)
-> (f (TKProduct (ADTensorKind z) x) -> f (ADTensorKind x))
-> f (ADTensorKind x)
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet f (TKProduct (ADTensorKind z) x)
db_a ((f (TKProduct (ADTensorKind z) x) -> f (ADTensorKind x))
 -> f (ADTensorKind x))
-> (f (TKProduct (ADTensorKind z) x) -> f (ADTensorKind x))
-> f (ADTensorKind x)
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct (ADTensorKind z) x)
db_aShared ->
          ShareOf f (ADTensorKind x) -> f (ADTensorKind x)
forall (y :: TK). ShareOf f y -> f y
forall (target :: Target) (y :: TK).
LetTensor target =>
ShareOf target y -> target y
tunshare (ShareOf f (ADTensorKind x) -> f (ADTensorKind x))
-> ShareOf f (ADTensorKind x) -> f (ADTensorKind x)
forall a b. (a -> b) -> a -> b
$ (ShareOf f (ADTensorKind x), ShareOf f z)
-> ShareOf f (ADTensorKind x)
forall a b. (a, b) -> a
fst ((ShareOf f (ADTensorKind x), ShareOf f z)
 -> ShareOf f (ADTensorKind x))
-> (ShareOf f (ADTensorKind x), ShareOf f z)
-> ShareOf f (ADTensorKind x)
forall a b. (a -> b) -> a -> b
$ Maybe (ShareOf f (ADTensorKind z))
-> (ADVal (ShareOf f) x -> ADVal (ShareOf f) z)
-> FullShapeTK x
-> ShareOf f x
-> (ShareOf f (ADTensorKind x), ShareOf f z)
forall (x :: TK) (z :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
Maybe (target (ADTensorKind z))
-> (ADVal target x -> ADVal target z)
-> FullShapeTK x
-> target x
-> (target (ADTensorKind x), target z)
crevOnParams
                             (ShareOf f (ADTensorKind z) -> Maybe (ShareOf f (ADTensorKind z))
forall a. a -> Maybe a
Just (ShareOf f (ADTensorKind z) -> Maybe (ShareOf f (ADTensorKind z)))
-> ShareOf f (ADTensorKind z) -> Maybe (ShareOf f (ADTensorKind z))
forall a b. (a -> b) -> a -> b
$ f (ADTensorKind z) -> ShareOf f (ADTensorKind z)
forall (y :: TK). f y -> ShareOf f y
forall (target :: Target) (y :: TK).
LetTensor target =>
target y -> ShareOf target y
toShare (f (ADTensorKind z) -> ShareOf f (ADTensorKind z))
-> f (ADTensorKind z) -> ShareOf f (ADTensorKind z)
forall a b. (a -> b) -> a -> b
$ f (TKProduct (ADTensorKind z) x) -> f (ADTensorKind z)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct (ADTensorKind z) x)
db_aShared)
                             (HFun x z -> forall (f :: Target). ADReady f => f x -> f z
forall (x :: TK) (z :: TK).
HFun x z -> forall (f :: Target). ADReady f => f x -> f z
unHFun HFun x z
h @(ADVal (ShareOf f)))
                             FullShapeTK x
xftk
                             (f x -> ShareOf f x
forall (y :: TK). f y -> ShareOf f y
forall (target :: Target) (y :: TK).
LetTensor target =>
target y -> ShareOf target y
toShare (f x -> ShareOf f x) -> f x -> ShareOf f x
forall a b. (a -> b) -> a -> b
$ f (TKProduct (ADTensorKind z) x) -> f x
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct (ADTensorKind z) x)
db_aShared)
    in (forall (f :: Target).
 ADReady f =>
 f (TKProduct (ADTensorKind z) x) -> f (ADTensorKind x))
-> HFun (TKProduct (ADTensorKind z) x) (ADTensorKind x)
forall (x :: TK) (z :: TK).
(forall (f :: Target). ADReady f => f x -> f z) -> HFun x z
HFun f (TKProduct (ADTensorKind z) x) -> f (ADTensorKind x)
forall (f :: Target).
ADReady f =>
f (TKProduct (ADTensorKind z) x) -> f (ADTensorKind x)
rf
  tjvp :: forall (x :: TK) (z :: TK).
FullShapeTK x
-> HFun x z
-> HFunOf
     (ADVal target) (TKProduct (ADTensorKind x) x) (ADTensorKind z)
tjvp @x @z FullShapeTK x
xftk HFun x z
h =
    let df :: forall f. ADReady f
           => f (TKProduct (ADTensorKind x) x)
           -> f (ADTensorKind z)
        -- This computes the derivative of g again for each new da and a.
        df :: forall (f :: Target).
ADReady f =>
f (TKProduct (ADTensorKind x) x) -> f (ADTensorKind z)
df !f (TKProduct (ADTensorKind x) x)
da_a = f (TKProduct (ADTensorKind x) x)
-> (f (TKProduct (ADTensorKind x) x) -> f (ADTensorKind z))
-> f (ADTensorKind z)
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet f (TKProduct (ADTensorKind x) x)
da_a ((f (TKProduct (ADTensorKind x) x) -> f (ADTensorKind z))
 -> f (ADTensorKind z))
-> (f (TKProduct (ADTensorKind x) x) -> f (ADTensorKind z))
-> f (ADTensorKind z)
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct (ADTensorKind x) x)
da_aShared ->
          ShareOf f (ADTensorKind z) -> f (ADTensorKind z)
forall (y :: TK). ShareOf f y -> f y
forall (target :: Target) (y :: TK).
LetTensor target =>
ShareOf target y -> target y
tunshare (ShareOf f (ADTensorKind z) -> f (ADTensorKind z))
-> ShareOf f (ADTensorKind z) -> f (ADTensorKind z)
forall a b. (a -> b) -> a -> b
$ (ShareOf f (ADTensorKind z), ShareOf f z)
-> ShareOf f (ADTensorKind z)
forall a b. (a, b) -> a
fst ((ShareOf f (ADTensorKind z), ShareOf f z)
 -> ShareOf f (ADTensorKind z))
-> (ShareOf f (ADTensorKind z), ShareOf f z)
-> ShareOf f (ADTensorKind z)
forall a b. (a -> b) -> a -> b
$ FullShapeTK x
-> ShareOf f x
-> (ADVal (ShareOf f) x -> ADVal (ShareOf f) z)
-> ShareOf f (ADTensorKind x)
-> (ShareOf f (ADTensorKind z), ShareOf f z)
forall (x :: TK) (z :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
FullShapeTK x
-> target x
-> (ADVal target x -> ADVal target z)
-> target (ADTensorKind x)
-> (target (ADTensorKind z), target z)
cfwdOnParams
                             FullShapeTK x
xftk
                             (f x -> ShareOf f x
forall (y :: TK). f y -> ShareOf f y
forall (target :: Target) (y :: TK).
LetTensor target =>
target y -> ShareOf target y
toShare (f x -> ShareOf f x) -> f x -> ShareOf f x
forall a b. (a -> b) -> a -> b
$ f (TKProduct (ADTensorKind x) x) -> f x
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct (ADTensorKind x) x)
da_aShared)
                             (HFun x z -> forall (f :: Target). ADReady f => f x -> f z
forall (x :: TK) (z :: TK).
HFun x z -> forall (f :: Target). ADReady f => f x -> f z
unHFun HFun x z
h @(ADVal (ShareOf f)))
                             (f (ADTensorKind x) -> ShareOf f (ADTensorKind x)
forall (y :: TK). f y -> ShareOf f y
forall (target :: Target) (y :: TK).
LetTensor target =>
target y -> ShareOf target y
toShare (f (ADTensorKind x) -> ShareOf f (ADTensorKind x))
-> f (ADTensorKind x) -> ShareOf f (ADTensorKind x)
forall a b. (a -> b) -> a -> b
$ f (TKProduct (ADTensorKind x) x) -> f (ADTensorKind x)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct (ADTensorKind x) x)
da_aShared)
    in (forall (f :: Target).
 ADReady f =>
 f (TKProduct (ADTensorKind x) x) -> f (ADTensorKind z))
-> HFun (TKProduct (ADTensorKind x) x) (ADTensorKind z)
forall (x :: TK) (z :: TK).
(forall (f :: Target). ADReady f => f x -> f z) -> HFun x z
HFun f (TKProduct (ADTensorKind x) x) -> f (ADTensorKind z)
forall (f :: Target).
ADReady f =>
f (TKProduct (ADTensorKind x) x) -> f (ADTensorKind z)
df

  tfromVector :: forall (y :: TK) (k :: Nat).
SNat k
-> SingletonTK y
-> Vector (ADVal target y)
-> ADVal target (BuildTensorKind k y)
tfromVector SNat k
snat SingletonTK y
stk Vector (ADVal target y)
lu =
    target (BuildTensorKind k y)
-> Delta target (BuildTensorKind k y)
-> ADVal target (BuildTensorKind k y)
forall (f :: Target) (z :: TK). f z -> Delta f z -> ADVal f z
dD (SNat k
-> SingletonTK y
-> Vector (target y)
-> target (BuildTensorKind k y)
forall (y :: TK) (k :: Nat).
SNat k
-> SingletonTK y
-> Vector (target y)
-> target (BuildTensorKind k y)
forall (target :: Target) (y :: TK) (k :: Nat).
BaseTensor target =>
SNat k
-> SingletonTK y
-> Vector (target y)
-> target (BuildTensorKind k y)
tfromVector SNat k
snat SingletonTK y
stk (Vector (target y) -> target (BuildTensorKind k y))
-> Vector (target y) -> target (BuildTensorKind k y)
forall a b. (a -> b) -> a -> b
$ (ADVal target y -> target y)
-> Vector (ADVal target y) -> Vector (target y)
forall (v :: Type -> Type) a b.
(Vector v a, Vector v b) =>
(a -> b) -> v a -> v b
V.map (\(D target y
u Delta target y
_) -> target y
u) Vector (ADVal target y)
lu)
       (SNat k
-> SingletonTK y
-> Vector (Delta target y)
-> Delta target (BuildTensorKind k y)
forall (y :: TK) (k :: Nat) (a :: Target).
SNat k
-> SingletonTK y
-> Vector (Delta a y)
-> Delta a (BuildTensorKind k y)
DeltaFromVector SNat k
snat SingletonTK y
stk (Vector (Delta target y) -> Delta target (BuildTensorKind k y))
-> Vector (Delta target y) -> Delta target (BuildTensorKind k y)
forall a b. (a -> b) -> a -> b
$ (ADVal target y -> Delta target y)
-> Vector (ADVal target y) -> Vector (Delta target y)
forall (v :: Type -> Type) a b.
(Vector v a, Vector v b) =>
(a -> b) -> v a -> v b
V.map (\(D target y
_ Delta target y
u') -> Delta target y
u') Vector (ADVal target y)
lu)

  treplTarget :: forall (y :: TK).
(forall r. GoodScalar r => r) -> FullShapeTK y -> ADVal target y
treplTarget forall r. GoodScalar r => r
r FullShapeTK y
ftk = target y -> Delta target y -> ADVal target y
forall (f :: Target) (z :: TK). f z -> Delta f z -> ADVal f z
dDnotShared ((forall r. GoodScalar r => r) -> FullShapeTK y -> target y
forall (y :: TK).
(forall r. GoodScalar r => r) -> FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
(forall r. GoodScalar r => r) -> FullShapeTK y -> target y
treplTarget r
forall r. GoodScalar r => r
r FullShapeTK y
ftk) (FullShapeTK y -> Delta target y
forall (x :: TK) (target :: Target).
FullShapeTK x -> Delta target x
DeltaZero FullShapeTK y
ftk)
  tdefTarget :: forall (y :: TK). FullShapeTK y -> ADVal target y
tdefTarget FullShapeTK y
ftk = target y -> Delta target y -> ADVal target y
forall (f :: Target) (z :: TK). f z -> Delta f z -> ADVal f z
dDnotShared (FullShapeTK y -> target y
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget FullShapeTK y
ftk) (FullShapeTK y -> Delta target y
forall (x :: TK) (target :: Target).
FullShapeTK x -> Delta target x
DeltaZero FullShapeTK y
ftk)
  taddTarget :: forall (y :: TK).
SingletonTK y -> ADVal target y -> ADVal target y -> ADVal target y
taddTarget = SingletonTK y -> ADVal target y -> ADVal target y -> ADVal target y
forall (target :: Target) (y :: TK).
(BaseTensor target, ConvertTensor target) =>
SingletonTK y -> target y -> target y -> target y
addTarget
  tmultTarget :: forall (y :: TK).
SingletonTK y -> ADVal target y -> ADVal target y -> ADVal target y
tmultTarget = SingletonTK y -> ADVal target y -> ADVal target y -> ADVal target y
forall (target :: Target) (y :: TK).
(BaseTensor target, ConvertTensor target) =>
SingletonTK y -> target y -> target y -> target y
multTarget
  tsum0Target :: forall (y :: TK).
FullShapeTK y -> ADVal target y -> ADVal target (TKScalar Double)
tsum0Target = FullShapeTK y -> ADVal target y -> ADVal target (TKScalar Double)
forall (target :: Target) (y :: TK).
(BaseTensor target, ConvertTensor target) =>
FullShapeTK y -> target y -> target (TKScalar Double)
sum0Target
  -- Delegates to the class-generic 'dot0Target', instantiated at 'ADVal'
  -- target; both arguments share the given full shape and the result is a
  -- 'TKScalar Double'.
  tdot0Target = dot0Target

instance ( ADReadyNoLet target, ShareTensor target
         , ShareTensor (PrimalOf target) )
         => ConvertTensor (ADVal target) where
  -- Applies one conversion to both halves of the dual number: the primal
  -- via the underlying target's 'tconvert', the dual by wrapping the delta
  -- in 'DeltaConvert'. The result delta is not shared ('dDnotShared').
  tconvert c astk (D u u') =
    dDnotShared (tconvert c astk u) $ DeltaConvert c u'

  -- Ranked-from-mixed is routed through the shaped representation.
  -- A shaped shape 'sh' is obtained from the runtime mixed shape of the
  -- delta. That shape is made known, then 'sfromX' and 'rfromS' are
  -- composed. The type application pins 'sfromX' to the recovered 'sh'.
  rfromX a@(D _ u') = case ftkDelta u' of
    FTKX sh' _ ->
      withShsFromShX sh' $ \(sh :: ShS sh) ->
        withKnownShS sh $
          rfromS $ sfromX @_ @sh a
  -- Mixed-from-ranked is routed through the shaped representation.
  -- A shaped shape 'sh' is obtained from the runtime ranked shape of the
  -- delta. That shape is made known, then 'sfromR' and 'xfromS' are
  -- composed. The type application pins 'xfromS' to the recovered 'sh'.
  xfromR a@(D _ u') = case ftkDelta u' of
    FTKR shr _ ->
      withShsFromShR shr $ \(sh :: ShS sh) ->
        withKnownShS sh $
          xfromS @_ @sh $ sfromR a

  -- Primal: the underlying 'sfromR'. Dual: 'dSFromR' at the statically
  -- known target shape ('knownShS'). The result delta is not shared.
  sfromR (D u u') = dDnotShared (sfromR u) $ dSFromR knownShS u'
  -- Primal: the underlying 'sfromX'. Dual: 'dSFromX' at the statically
  -- known target shape ('knownShS'). The result delta is not shared.
  sfromX (D u u') = dDnotShared (sfromX u) $ dSFromX knownShS u'
  -- Primal: the underlying 'xfromS'. Dual: 'dXFromS' at the statically
  -- known mixed target shape ('knownShX'). The result delta is not shared.
  xfromS (D u u') = dDnotShared (xfromS u) $ dXFromS knownShX u'

  -- Zips a pair of rank-n tensors into one rank-n tensor of pairs, as a
  -- single composite conversion. The conversion works in three steps:
  --   1. 'ConvT2 ConvRX ConvRX' views each component as a mixed tensor
  --      whose shape is 'Replicate n Nothing'.
  --   2. 'ConvZip' zips the two mixed tensors.
  --   3. 'ConvXR' views the zipped result as ranked again.
  -- The 'lemRankReplicate' match provides
  -- 'Rank (Replicate n Nothing) ~ n', which step 3 needs to land back at
  -- rank n. The same conversion is applied to the primal ('tconvert') and
  -- to the delta ('DeltaConvert').
  -- NOTE(review): this builds the result with 'dD', whereas 'tconvert' in
  -- this instance uses 'dDnotShared'; confirm the difference is intended.
  rzip @_ @_ @n (D u u')
    | Refl <- lemRankReplicate (Proxy @n) = case ftkDelta u' of
      ftk@(FTKProduct (FTKR _ y) (FTKR _ z)) ->
        let c = ConvCmp
                  (ConvXR (ftkToSTK (FTKProduct y z)))
                  (ConvCmp (ConvZip (ftkToSTK y) (ftkToSTK z))
                           (ConvT2 ConvRX ConvRX))
        in dD (tconvert c (ftkToSTK ftk) u) (DeltaConvert c u')
  -- Unzips a rank-n tensor of pairs into a pair of rank-n tensors, as a
  -- single composite conversion. The conversion works in three steps:
  --   1. 'ConvRX' views the input as a mixed tensor whose shape is
  --      'Replicate n Nothing'.
  --   2. 'ConvUnzip' splits it into two mixed tensors.
  --   3. 'ConvXR' is applied to each component to return to ranked form.
  -- The 'lemRankReplicate' match provides
  -- 'Rank (Replicate n Nothing) ~ n' for the 'ConvXR' steps. The same
  -- conversion is applied to the primal and to the delta.
  runzip @_ @_ @n (D u u')
    | Refl <- lemRankReplicate (Proxy @n) = case ftkDelta u' of
      ftk@(FTKR _ (FTKProduct y z)) ->
        let c = ConvCmp
                  (ConvT2 (ConvXR (ftkToSTK y)) (ConvXR (ftkToSTK z)))
                  (ConvCmp (ConvUnzip (ftkToSTK y) (ftkToSTK z))
                           ConvRX)
        in dD (tconvert c (ftkToSTK ftk) u) (DeltaConvert c u')
  -- Zips a pair of shaped tensors into one shaped tensor of pairs, as a
  -- single composite conversion. The conversion works in three steps:
  --   1. 'ConvT2 ConvSX ConvSX' views each component as a mixed tensor
  --      whose shape is 'MapJust sh'.
  --   2. 'ConvZip' zips the two mixed tensors.
  --   3. 'ConvXS' views the zipped result as shaped again.
  -- The same conversion is applied to the primal and to the delta.
  szip (D u u') = case ftkDelta u' of
    ftk@(FTKProduct (FTKS _ y) (FTKS _ z)) ->
      let c = ConvCmp
                ConvXS
                (ConvCmp (ConvZip (ftkToSTK y) (ftkToSTK z))
                         (ConvT2 ConvSX ConvSX))
      in dD (tconvert c (ftkToSTK ftk) u) (DeltaConvert c u')
  -- Unzips a shaped tensor of pairs into a pair of shaped tensors, as a
  -- single composite conversion. The conversion works in three steps:
  --   1. 'ConvSX' views the input as a mixed tensor whose shape is
  --      'MapJust sh'.
  --   2. 'ConvUnzip' splits it into two mixed tensors.
  --   3. 'ConvXS' is applied to each component to return to shaped form.
  -- The same conversion is applied to the primal and to the delta.
  sunzip (D u u') = case ftkDelta u' of
    ftk@(FTKS _ (FTKProduct y z)) ->
      let c = ConvCmp
                (ConvT2 ConvXS ConvXS)
                (ConvCmp (ConvUnzip (ftkToSTK y) (ftkToSTK z))
                         ConvSX)
      in dD (tconvert c (ftkToSTK ftk) u) (DeltaConvert c u')
  -- Zips a pair of mixed tensors of the same shape into one mixed tensor
  -- of pairs. Mixed is already the representation 'ConvZip' works on, so
  -- the conversion is 'ConvZip' alone. It is applied to the primal and to
  -- the delta.
  xzip (D u u') = case ftkDelta u' of
    ftk@(FTKProduct (FTKX _ y) (FTKX _ z)) ->
      let c = ConvZip (ftkToSTK y) (ftkToSTK z)
      in dD (tconvert c (ftkToSTK ftk) u) (DeltaConvert c u')
  -- Unzip a mixed tensor of pairs into a pair of mixed tensors.
  -- The element singletons come from the dual part's full shape.
  -- The primal is converted with 'tconvert' and the dual wrapped in
  -- 'DeltaConvert', both with the same 'ConvUnzip' conversion.
  xunzip (D u u') = case ftkDelta u' of
    ftk@(FTKX _sh (FTKProduct y z)) ->
      let c = ConvUnzip (ftkToSTK y) (ftkToSTK z)
      in dD (tconvert c (ftkToSTK ftk) u)
            (DeltaConvert c u')

  -- Nest the trailing @m@ unknown-size dimensions into ranked elements.
  -- 'lemRankReplicate' provides @Rank (Replicate m Nothing) ~ m@, which
  -- is needed so that 'ConvXR' yields a 'TKR2' of rank exactly @m@.
  xnestR @sh1 @m @x sh1 (D u u')
    | Refl <- lemRankReplicate (Proxy @m) =
      let c :: TKConversion (TKX2 (sh1 ++ Replicate m Nothing) x)
                            (TKX2 sh1 (TKR2 m x))
          -- First split off the inner dimensions as a nested mixed
          -- tensor, then convert each inner mixed tensor to ranked.
          c = ConvCmp
                (ConvXX (ConvXR (knownSTK @x)))
                (ConvNest @_ @_ @(Replicate m Nothing)
                          (STKX sh1 (knownSTK @x)))
      in dD (tconvert c (ftkToSTK $ ftkDelta u') u)
            (DeltaConvert c u')
  -- Nest the trailing statically-known dimensions into shaped elements.
  -- 'ConvNest' splits off the inner mixed tensor, then 'ConvXS' turns it
  -- into a shaped one. @c@ is left without a signature because @sh2@ is
  -- not bound here; it is fixed by the use sites.
  xnestS @_ @_ @x sh1 (D u u') =
    let c = ConvCmp (ConvXX ConvXS)
                    (ConvNest (STKX sh1 (knownSTK @x)))
    in dD (tconvert c (ftkToSTK $ ftkDelta u') u)
          (DeltaConvert c u')
  -- Nest the trailing dimensions of a mixed tensor as mixed elements.
  -- The same 'ConvNest' conversion is applied to the primal and the dual.
  xnest @_ @_ @x sh1 (D u u') =
    let c = ConvNest (STKX sh1 (knownSTK @x))
    in dD (tconvert c (ftkToSTK $ ftkDelta u') u)
          (DeltaConvert c u')
  -- Flatten ranked elements back into trailing unknown-size dimensions.
  -- Each ranked element is first converted to mixed ('ConvRX'), then the
  -- nesting is removed with 'ConvUnnest'.
  xunNestR (D u u') =
    let c = ConvCmp ConvUnnest (ConvXX ConvRX)
    in dD (tconvert c (ftkToSTK $ ftkDelta u') u)
          (DeltaConvert c u')
  -- Flatten shaped elements back into trailing statically-known dimensions.
  -- Each shaped element is first converted to mixed ('ConvSX'), then the
  -- nesting is removed with 'ConvUnnest'.
  xunNestS (D u u') =
    let c = ConvCmp ConvUnnest (ConvXX ConvSX)
    in dD (tconvert c (ftkToSTK $ ftkDelta u') u)
          (DeltaConvert c u')
  -- Flatten mixed elements back into the outer mixed tensor's dimensions.
  -- The same 'ConvUnnest' conversion is applied to the primal and the dual.
  xunNest (D u u') =
    let c = ConvUnnest
    in dD (tconvert c (ftkToSTK $ ftkDelta u') u)
          (DeltaConvert c u')

  -- Pairing for conversions delegates to the 'BaseTensor' 'tpair'.
  tpairConv = tpair
  -- Unpairing for conversions delegates to the 'ShareTensor' 'tunpair'.
  tunpairConv = tunpair