{-# LANGUAGE AllowAmbiguousTypes, QuantifiedConstraints #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module HordeAd.Core.DeltaEval
(
gradientFromDelta, derivativeFromDelta
, evalRev, evalRevFTK, evalRevSame, evalRevFromnMap, EvalState
) where
import Prelude
import Control.Arrow (second)
import Control.Exception.Assert.Sugar
import Data.Dependent.EnumMap.Strict (DEnumMap)
import Data.Dependent.EnumMap.Strict qualified as DMap
import Data.Dependent.Sum (DSum (..))
import Data.Proxy (Proxy (Proxy))
import Data.Traversable (mapAccumL)
import Data.Type.Equality (gcastWith, testEquality, (:~:) (Refl))
import Data.Vector.Generic qualified as V
import Text.Show (showListWith)
import Text.Show.Functions ()
import Type.Reflection (typeRep)
import Data.Array.Nested (type (++))
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation (permInverse)
import Data.Array.Nested.Permutation qualified as Permutation
import Data.Array.Nested.Ranked.Shape
import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Types (unsafeCoerceRefl)
import HordeAd.Core.ConvertTensor
import HordeAd.Core.Delta
import HordeAd.Core.Ops
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
import HordeAd.Core.Unwind
-- | Compute the gradient of a delta expression with respect to its inputs
-- (reverse mode).
--
-- The evaluation state is initialised for inputs of shape @xftk@ with
-- 'initEvalState'. The seed cotangent @dt@ (for an output of shape @zftk@)
-- is handed to 'evalRev' together with @deltaTopLevel@, and
-- 'evalRevFromnMap' is then run on the resulting state. The entries of the
-- final state's input map, in ascending key order, are reassembled by
-- 'rebuildInputs' into a single value of shape @adFTK xftk@. The assertion
-- checks that every entry was consumed by that reassembly.
gradientFromDelta
  :: forall x z target. (ADReadyNoLet target, ShareTensor target)
  => FullShapeTK x
  -> FullShapeTK z
  -> target (ADTensorKind z)
  -> Delta target z
  -> target (ADTensorKind x)
gradientFromDelta !xftk !zftk !dt deltaTopLevel =
  assert (null leftovers) gradient
 where
  finalState :: EvalState target
  finalState =
    evalRevFromnMap (evalRev zftk (initEvalState xftk) dt deltaTopLevel)
  -- Entries are consumed front to back, i.e. in ascending input-id order.
  (gradient, leftovers) =
    rebuildInputs @(ADTensorKind x)
                  (DMap.toAscList (iMap finalState))
                  finalState
                  (adFTK xftk)
-- | Compute the derivative of a delta expression along a given tangent
-- (forward mode).
--
-- Every differentiable leaf of the tangent @ds@ (of shape @ftk@) is paired
-- with an input id by 'generateDSums', numbering from 0. The resulting input
-- map and an initially empty node map are passed to 'evalFwd', and the
-- tangent component of its result (of kind @ADTensorKind z@) is returned.
-- Both components of that result are forced before returning.
derivativeFromDelta
  :: forall x z target. (ADReadyNoLet target, ShareTensor target)
  => Delta target z -> FullShapeTK (ADTensorKind x)
  -> target (ADTensorKind x)
  -> target (ADTensorKind z)
derivativeFromDelta deltaTopLevel ftk ds =
  case evalFwd inputs DMap.empty deltaTopLevel of
    (!_finalNodes, !tangent) -> tangent
 where
  inputs = DMap.fromDistinctAscList (fst (generateDSums 0 ftk ds))
-- | Map from delta-node ids to cotangents: a key @NodeId target y@ is
-- associated with a value @Cotangent target y@.
type ADMap target = DEnumMap (NodeId target) (Cotangent target)
-- | Map from input ids to input values: a key @InputId target y@ is
-- associated with a @TensorOrZero target y@ (a real tensor or a
-- shape-only placeholder).
type IMap target = DEnumMap (InputId target) (TensorOrZero target)
-- | 'showsPrec'-style printer for an 'IMap'.
--
-- The map is rendered as @fromList [k :=> v, ...]@ (entries in 'DMap.toList'
-- order) and parenthesised when the surrounding precedence exceeds 10. For
-- each entry, 'KnownSTK' is recovered from the entry's input id so that the
-- quantified 'Show' constraint can provide the value's instance.
showsPrec_IMap
  :: forall target.
     (forall y. KnownSTK y => Show (TensorOrZero target y))
  => Int -> IMap target -> ShowS
showsPrec_IMap prec m =
  showParen (prec > 10) $
    showString "fromList " . showListWith showEntry (DMap.toList m)
 where
  showEntry :: DSum (InputId target) (TensorOrZero target) -> ShowS
  showEntry (inputId :=> value) =
    withKnownSTK (ftkToSTK (inputIdToFTK inputId)) $
      showsPrec 2 inputId . showString " :=> " . showsPrec 1 value
-- | Render an 'IMap' as a 'String' via 'showsPrec_IMap' at precedence 0.
show_IMap
  :: (forall y. KnownSTK y => Show (TensorOrZero target y))
  => IMap target -> String
show_IMap m = showsPrec_IMap 0 m ""
type role Cotangent nominal nominal
-- | The cotangent belonging to a value of kind @y@.
--
-- It is stored at the AD tensor kind of @y@. This lets 'ADMap' associate
-- it with a key of type @NodeId target y@.
newtype Cotangent target y = Cotangent
  { unCotangent :: target (ADTensorKind y)
      -- ^ the wrapped cotangent value
  }
type role TensorOrZero nominal nominal
-- | An input value. It is either an actual tensor ('TOTensor') or a
-- placeholder ('TOZero') that records only the full shape.
-- 'evalTensorOrZero' materialises a placeholder with 'tdefTarget'
-- (presumably the default, i.e. zero, tensor of that shape).
data TensorOrZero target y
  = TOTensor (target y)
  | TOZero (FullShapeTK y)
  deriving Show
-- | Materialise a 'TensorOrZero'. A real tensor is returned as is; a
-- placeholder becomes @tdefTarget ftk@ for its recorded shape.
evalTensorOrZero :: forall target x. ADReadyNoLet target
                 => TensorOrZero target x -> target x
evalTensorOrZero (TOTensor t) = t
evalTensorOrZero (TOZero ftk) = tdefTarget ftk
-- | Sum two 'TensorOrZero' values of kind @y@, where @stk@ is the singleton
-- for @y@.
--
-- The operands are combined with 'taddTarget' only when both are real
-- tensors. A 'TOZero' operand is dropped and the other operand is returned
-- unchanged. When the first operand is 'TOZero', the second is not even
-- inspected.
addTensorOrZero :: forall target y. (ADReadyNoLet target, ShareTensor target)
                => SingletonTK y
                -> TensorOrZero target y -> TensorOrZero target y
                -> TensorOrZero target y
addTensorOrZero stk (TOTensor ta) (TOTensor tb) =
  TOTensor (taddTarget stk ta tb)
addTensorOrZero _ TOZero{} b = b
addTensorOrZero _ a TOZero{} = a
-- | Rebuild a value of shape @ftk@ from the front of the entry list @els@.
-- Returns that value together with the entries that were not consumed.
--
-- Traversal rules:
--
--   * Products are rebuilt left component first, then right.
--   * Each differentiable leaf consumes exactly one entry, whose recorded
--     shape must match the leaf.
--   * Non-differentiable leaves consume nothing and become @tdefTarget ftk@.
--
-- @s2@ is only used to dump its input map into the error messages. An error
-- is raised when the next entry has a mismatching shape, or when the list
-- runs out while a differentiable leaf still needs an entry.
rebuildInputs :: forall ady target. ADReadyNoLet target
              => [DSum (InputId target) (TensorOrZero target)]
              -> EvalState target
              -> FullShapeTK ady
              -> (target ady, [DSum (InputId target) (TensorOrZero target)])
rebuildInputs els s2 ftk = case ftk of
  FTKProduct ftkL ftkR ->
    let (tL, afterL) = rebuildInputs els s2 ftkL
        (tR, afterR) = rebuildInputs afterL s2 ftkR
        !paired = tpair tL tR
    in (paired, afterR)
  _ | not (differentiableFTK ftk) ->
    -- No entry was allocated for a non-differentiable leaf.
    (tdefTarget ftk, els)
  _ -> case els of
    (n :=> tz@(TOTensor t)) : rest
      | Just Refl <- matchingFTK (inputIdToFTK n) ftk -> (t, rest)
      -- The Dict brings KnownSTK into scope so the entry can be shown.
      | Dict <- lemKnownSTK (ftkToSTK (inputIdToFTK n)) ->
          error $ "rebuildInputs: wrong Tensor type: "
                  ++ show (n, tz, dumpIMap)
    (n :=> tz@(TOZero ftkN)) : rest
      | Just Refl <- matchingFTK ftkN ftk ->
          let zero = tdefTarget ftk in zero `seq` (zero, rest)
      | Dict <- lemKnownSTK (ftkToSTK (inputIdToFTK n)) ->
          error $ "rebuildInputs: wrong Zero type: "
                  ++ show (n, tz, dumpIMap)
    _ -> error $ "rebuildInputs: illegal TensorOrZero: " ++ dumpIMap
 where
  dumpIMap = show_IMap (iMap s2)
-- | Allocate input entries for every differentiable leaf of @ftk@, visiting
-- products left component first.
--
-- Leaves get consecutive ids starting at @j@ and a shape-only 'TOZero'
-- value. Non-differentiable leaves get no entry. Returns the entries, in
-- increasing id order, and the next unused id.
generateDSumsDummy :: Int -> FullShapeTK y
                   -> ([DSum (InputId target) (TensorOrZero target)], Int)
generateDSumsDummy j (FTKProduct ftk1 ftk2) =
  let (dsL, jL) = generateDSumsDummy j ftk1
      (dsR, jR) = generateDSumsDummy jL ftk2
  in (dsL ++ dsR, jR)
generateDSumsDummy j ftk
  | differentiableFTK ftk = ([mkInputId ftk j :=> TOZero ftk], j + 1)
  | otherwise = ([], j)
-- | Allocate input entries for every differentiable leaf of the value @t@
-- (of shape @ftk@), visiting products left component first.
--
-- Product values are split with 'tunpair'. Each differentiable leaf gets the
-- next id, starting at @j@, and holds its component of @t@ as 'TOTensor'.
-- Non-differentiable leaves get no entry. Returns the entries, in increasing
-- id order, and the next unused id.
generateDSums :: ShareTensor target
              => Int -> FullShapeTK y -> target y
              -> ([DSum (InputId target) (TensorOrZero target)], Int)
generateDSums j (FTKProduct ftk1 ftk2) t =
  let (tL, tR) = tunpair t
      (dsL, jL) = generateDSums j ftk1 tL
      (dsR, jR) = generateDSums jL ftk2 tR
  in (dsL ++ dsR, jR)
generateDSums j ftk t
  | differentiableFTK ftk = ([mkInputId ftk j :=> TOTensor t], j + 1)
  | otherwise = ([], j)
type role EvalState nominal
-- | State threaded through reverse-mode evaluation of a delta expression.
data EvalState target = EvalState
  { iMap :: IMap target
      -- ^ one entry per differentiable input leaf, keyed by input id;
      -- 'initEvalState' fills it with 'TOZero' placeholders
  , dMap :: ADMap target
      -- ^ cotangents keyed by delta-node id; starts empty
  , nMap :: DEnumMap (NodeId target) (Delta target)
      -- ^ delta nodes keyed by node id; starts empty.
      -- NOTE(review): presumably nodes still awaiting evaluation, drained
      -- by 'evalRevFromnMap' — confirm against its definition
  }
-- | Initial evaluation state for inputs of shape @ftk0@.
--
-- The input map gets a shape-only 'TOZero' entry for every differentiable
-- leaf of @adFTK ftk0@, with consecutive ids starting at 0. The ids are
-- allocated in increasing order, as 'DMap.fromDistinctAscList' expects. The
-- cotangent map and node map both start empty.
initEvalState :: FullShapeTK x -> EvalState target
initEvalState ftk0 = EvalState
  { iMap = DMap.fromDistinctAscList
             (fst (generateDSumsDummy 0 (adFTK ftk0)))
  , dMap = DMap.empty
  , nMap = DMap.empty
  }
-- | Dispatch on the runtime scalar type @r@.
--
-- For 'Double' and 'Float', the state @s@ and cotangent @c@ are passed to
-- 'evalRevSame' instantiated at the concrete kind. For any other scalar
-- type, the delta is ignored and @s@ is returned unchanged.
-- NOTE(review): presumably only 'Double' and 'Float' carry a non-trivial
-- gradient here — confirm.
evalRevScalarRuntimeSpecialized
  :: forall r target.
     (GoodScalar r, ADReadyNoLet target, ShareTensor target)
  => EvalState target -> target (ADTensorKind (TKScalar r))
  -> Delta target (TKScalar r)
  -> EvalState target
{-# INLINE evalRevScalarRuntimeSpecialized #-}
evalRevScalarRuntimeSpecialized !s !c
  | Just Refl <- testEquality (typeRep @r) (typeRep @Double) =
      evalRevSame @(TKScalar Double) s c
  | Just Refl <- testEquality (typeRep @r) (typeRep @Float) =
      evalRevSame @(TKScalar Float) s c
  | otherwise = const s
-- | Ranked-tensor counterpart of the scalar runtime dispatch.
--
-- For element type 'Double' or 'Float', the state @s@ and cotangent @c@ are
-- passed to 'evalRevSame' at @TKR n Double@ or @TKR n Float@. For any other
-- element type, the delta is ignored and @s@ is returned unchanged.
-- NOTE(review): presumably only 'Double' and 'Float' carry a non-trivial
-- gradient here — confirm.
evalRevRRuntimeSpecialized
  :: forall n r target.
     (GoodScalar r, ADReadyNoLet target, ShareTensor target)
  => EvalState target -> target (ADTensorKind (TKR n r))
  -> Delta target (TKR n r)
  -> EvalState target
{-# INLINE evalRevRRuntimeSpecialized #-}
evalRevRRuntimeSpecialized !s !c
  | Just Refl <- testEquality (typeRep @r) (typeRep @Double) =
      evalRevSame @(TKR n Double) s c
  | Just Refl <- testEquality (typeRep @r) (typeRep @Float) =
      evalRevSame @(TKR n Float) s c
  | otherwise = const s
-- | Reverse-pass evaluation of a shaped-tensor delta whose element type
-- @r@ is only known via its 'Typeable' evidence.
--
-- The element type is compared at runtime with 'Double' and then 'Float'.
-- On a match, the resulting type equality lets 'evalRevSame' be called at
-- the concrete kind @TKS sh Double@ / @TKS sh Float@. (Presumably this is
-- so that monomorphic code runs for these common types; TODO confirm the
-- performance motivation.) For any other element type, the delta is
-- ignored and the evaluation state is returned unchanged.
evalSRuntimeSpecialized
  :: forall sh r target.
     (GoodScalar r, ADReadyNoLet target, ShareTensor target)
  => EvalState target -> target (ADTensorKind (TKS sh r))
  -> Delta target (TKS sh r)
  -> EvalState target
{-# INLINE evalSRuntimeSpecialized #-}
evalSRuntimeSpecialized !state !cot
  | Just Refl <- testEquality (typeRep @r) (typeRep @Double) =
      evalRevSame @(TKS sh Double) state cot
  | Just Refl <- testEquality (typeRep @r) (typeRep @Float) =
      evalRevSame @(TKS sh Float) state cot
  | otherwise =
      -- Element type not in the dispatch list: drop the delta.
      const state
-- | Reverse-pass evaluation of a mixed-shape tensor delta whose element
-- type @r@ is only known via its 'Typeable' evidence.
--
-- The element type is compared at runtime with 'Double' and then 'Float'.
-- On a match, the resulting type equality lets 'evalRevSame' be called at
-- the concrete kind @TKX sh Double@ / @TKX sh Float@. (Presumably this is
-- so that monomorphic code runs for these common types; TODO confirm the
-- performance motivation.) For any other element type, the delta is
-- ignored and the evaluation state is returned unchanged.
evalXRuntimeSpecialized
  :: forall sh r target.
     (GoodScalar r, ADReadyNoLet target, ShareTensor target)
  => EvalState target -> target (ADTensorKind (TKX sh r))
  -> Delta target (TKX sh r)
  -> EvalState target
{-# INLINE evalXRuntimeSpecialized #-}
evalXRuntimeSpecialized !state !cot
  | Just Refl <- testEquality (typeRep @r) (typeRep @Double) =
      evalRevSame @(TKX sh Double) state cot
  | Just Refl <- testEquality (typeRep @r) (typeRep @Float) =
      evalRevSame @(TKX sh Float) state cot
  | otherwise =
      -- Element type not in the dispatch list: drop the delta.
      const state
-- | Reverse-pass evaluation of a delta expression, given the full shape
-- of its result.
--
-- Two groups of shapes are routed to helpers that dispatch on the
-- element type:
--
-- * bare scalars go to 'evalRevScalarRuntimeSpecialized';
-- * ranked, shaped and mixed tensors whose elements are plain scalars go
--   to 'evalRevRRuntimeSpecialized', 'evalSRuntimeSpecialized' and
--   'evalXRuntimeSpecialized' respectively.
--
-- Every other shape falls through to the generic 'evalRevFTK'. The state
-- and cotangent arguments are strict.
evalRev
  :: forall y target.
     (ADReadyNoLet target, ShareTensor target)
  => FullShapeTK y
  -> EvalState target -> target (ADTensorKind y) -> Delta target y
  -> EvalState target
evalRev (FTKScalar @r) !st !ct delta =
  evalRevScalarRuntimeSpecialized @r st ct delta
evalRev (FTKR @n _ (FTKScalar @r)) !st !ct delta =
  evalRevRRuntimeSpecialized @n @r st ct delta
evalRev (FTKS @sh _ (FTKScalar @r)) !st !ct delta =
  evalSRuntimeSpecialized @sh @r st ct delta
evalRev (FTKX @sh _ (FTKScalar @r)) !st !ct delta =
  evalXRuntimeSpecialized @sh @r st ct delta
evalRev _ !st !ct delta =
  -- Products, nested tensors, etc.: no element-type specialisation.
  evalRevFTK st ct delta
evalRevFTK
:: forall y target.
(ADReadyNoLet target, ShareTensor target)
=> EvalState target -> target (ADTensorKind y) -> Delta target y
-> EvalState target
evalRevFTK :: forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevFTK !EvalState target
s !target (ADTensorKind y)
c Delta target y
d0 = case Delta target y
d0 of
DeltaShare NodeId target y
n Delta target y
d ->
Bool -> EvalState target -> EvalState target
forall a. (?callStack::CallStack) => Bool -> a -> a
assert (case Delta target y
d of
DeltaZero{} -> Bool
False
DeltaPair{} -> Bool
False
DeltaInput{} -> Bool
False
DeltaShare{} -> Bool
False
Delta target y
_ -> Bool
True)
(EvalState target -> EvalState target)
-> EvalState target -> EvalState target
forall a b. (a -> b) -> a -> b
$ if NodeId target y
-> DEnumMap @TK (NodeId target) (Delta target) -> Bool
forall {kind} (k :: kind -> Type) (a :: kind) (v :: kind -> Type).
Enum1 @kind k =>
k a -> DEnumMap @kind k v -> Bool
DMap.member NodeId target y
n (DEnumMap @TK (NodeId target) (Delta target) -> Bool)
-> DEnumMap @TK (NodeId target) (Delta target) -> Bool
forall a b. (a -> b) -> a -> b
$ EvalState target -> DEnumMap @TK (NodeId target) (Delta target)
forall (target :: Target).
EvalState target -> DEnumMap @TK (NodeId target) (Delta target)
nMap EvalState target
s
then let addc :: Cotangent target y -> Cotangent target y
addc Cotangent target y
x =
target (ADTensorKind y) -> Cotangent target y
forall (target :: Target) (y :: TK).
target (ADTensorKind y) -> Cotangent target y
Cotangent (target (ADTensorKind y) -> Cotangent target y)
-> target (ADTensorKind y) -> Cotangent target y
forall a b. (a -> b) -> a -> b
$ SingletonTK (ADTensorKind y)
-> target (ADTensorKind y)
-> target (ADTensorKind y)
-> target (ADTensorKind y)
forall (y :: TK). SingletonTK y -> target y -> target y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> target y -> target y
taddTarget (SingletonTK y -> SingletonTK (ADTensorKind y)
forall (y :: TK). SingletonTK y -> SingletonTK (ADTensorKind y)
adSTK (SingletonTK y -> SingletonTK (ADTensorKind y))
-> SingletonTK y -> SingletonTK (ADTensorKind y)
forall a b. (a -> b) -> a -> b
$ FullShapeTK y -> SingletonTK y
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (FullShapeTK y -> SingletonTK y) -> FullShapeTK y -> SingletonTK y
forall a b. (a -> b) -> a -> b
$ NodeId target y -> FullShapeTK y
forall (f :: Target) (y :: TK). NodeId f y -> FullShapeTK y
nodeIdToFTK NodeId target y
n)
target (ADTensorKind y)
c (Cotangent target y -> target (ADTensorKind y)
forall (target :: Target) (y :: TK).
Cotangent target y -> target (ADTensorKind y)
unCotangent Cotangent target y
x)
in EvalState target
s {dMap = DMap.adjust addc n $ dMap s}
else let cd :: Cotangent target y
cd = target (ADTensorKind y) -> Cotangent target y
forall (target :: Target) (y :: TK).
target (ADTensorKind y) -> Cotangent target y
Cotangent target (ADTensorKind y)
c
in EvalState target
s { nMap = DMap.insert n d $ nMap s
, dMap = DMap.insert n cd $ dMap s }
DeltaPair Delta target y
d1 Delta target z
d2 ->
let (target (ADTensorKind y)
c1, target (ADTensorKind z)
c2) = target (TKProduct (ADTensorKind y) (ADTensorKind z))
-> (target (ADTensorKind y), target (ADTensorKind z))
forall (x :: TK) (z :: TK).
target (TKProduct x z) -> (target x, target z)
forall (target :: Target) (x :: TK) (z :: TK).
ShareTensor target =>
target (TKProduct x z) -> (target x, target z)
tunpair target (ADTensorKind y)
target (TKProduct (ADTensorKind y) (ADTensorKind z))
c
in EvalState target
-> target (ADTensorKind z) -> Delta target z -> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevFTK (EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevFTK EvalState target
s target (ADTensorKind y)
c1 Delta target y
d1) target (ADTensorKind z)
c2 Delta target z
d2
DeltaProject1 Delta target (TKProduct y z)
d -> case Delta target (TKProduct y z) -> FullShapeTK (TKProduct y z)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKProduct y z)
d of
FTKProduct FullShapeTK y1
_ FullShapeTK z
ftk2 ->
let zero :: target (ADTensorKind z)
zero = FullShapeTK (ADTensorKind z) -> target (ADTensorKind z)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget (FullShapeTK (ADTensorKind z) -> target (ADTensorKind z))
-> FullShapeTK (ADTensorKind z) -> target (ADTensorKind z)
forall a b. (a -> b) -> a -> b
$ FullShapeTK z -> FullShapeTK (ADTensorKind z)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK z
ftk2
in EvalState target
-> target (ADTensorKind (TKProduct y z))
-> Delta target (TKProduct y z)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevFTK EvalState target
s (target (ADTensorKind y)
-> target (ADTensorKind z)
-> target (TKProduct (ADTensorKind y) (ADTensorKind z))
forall (x :: TK) (z :: TK).
target x -> target z -> target (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair target (ADTensorKind y)
c target (ADTensorKind z)
zero) Delta target (TKProduct y z)
d
DeltaProject2 Delta target (TKProduct y y)
d -> case Delta target (TKProduct y y) -> FullShapeTK (TKProduct y y)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKProduct y y)
d of
FTKProduct FullShapeTK y1
ftk1 FullShapeTK z
_ ->
let zero :: target (ADTensorKind y)
zero = FullShapeTK (ADTensorKind y) -> target (ADTensorKind y)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget (FullShapeTK (ADTensorKind y) -> target (ADTensorKind y))
-> FullShapeTK (ADTensorKind y) -> target (ADTensorKind y)
forall a b. (a -> b) -> a -> b
$ FullShapeTK y1 -> FullShapeTK (ADTensorKind y1)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK y1
ftk1
in EvalState target
-> target (ADTensorKind (TKProduct y y))
-> Delta target (TKProduct y y)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevFTK EvalState target
s (target (ADTensorKind y)
-> target (ADTensorKind y)
-> target (TKProduct (ADTensorKind y) (ADTensorKind y))
forall (x :: TK) (z :: TK).
target x -> target z -> target (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair target (ADTensorKind y)
zero target (ADTensorKind y)
c) Delta target (TKProduct y y)
d
DeltaFromVector SNat k
snat SingletonTK y
stk Vector (Delta target y)
ld | (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
Refl <- SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
forall (k :: Nat) (y :: TK).
SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
lemBuildOfAD SNat k
snat SingletonTK y
stk ->
let cxs :: [target (ADTensorKind y)]
cxs = SNat k
-> SingletonTK (ADTensorKind y)
-> target (BuildTensorKind k (ADTensorKind y))
-> [target (ADTensorKind y)]
forall (y :: TK) (k :: Nat).
(BaseTensor target, ConvertTensor target) =>
SNat k
-> SingletonTK y -> target (BuildTensorKind k y) -> [target y]
forall (target :: Target) (y :: TK) (k :: Nat).
(ShareTensor target, BaseTensor target, ConvertTensor target) =>
SNat k
-> SingletonTK y -> target (BuildTensorKind k y) -> [target y]
tunravelToListShare SNat k
snat (SingletonTK y -> SingletonTK (ADTensorKind y)
forall (y :: TK). SingletonTK y -> SingletonTK (ADTensorKind y)
adSTK SingletonTK y
stk) target (ADTensorKind y)
target (BuildTensorKind k (ADTensorKind y))
c
in (EvalState target
-> (target (ADTensorKind y), Delta target y) -> EvalState target)
-> EvalState target
-> [(target (ADTensorKind y), Delta target y)]
-> EvalState target
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: Type -> Type) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' (\ !EvalState target
s2 (target (ADTensorKind y)
cx, Delta target y
d2) -> EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevFTK EvalState target
s2 target (ADTensorKind y)
cx Delta target y
d2) EvalState target
s
([(target (ADTensorKind y), Delta target y)] -> EvalState target)
-> [(target (ADTensorKind y), Delta target y)] -> EvalState target
forall a b. (a -> b) -> a -> b
$ [target (ADTensorKind y)]
-> [Delta target y] -> [(target (ADTensorKind y), Delta target y)]
forall a b. [a] -> [b] -> [(a, b)]
zip [target (ADTensorKind y)]
cxs (Vector (Delta target y) -> [Delta target y]
forall (v :: Type -> Type) a. Vector v a => v a -> [a]
V.toList Vector (Delta target y)
ld)
DeltaSum SNat k
snat SingletonTK y
stk Delta target (BuildTensorKind k y)
d | (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
Refl <- SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
forall (k :: Nat) (y :: TK).
SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
lemBuildOfAD SNat k
snat SingletonTK y
stk ->
EvalState target
-> target (ADTensorKind (BuildTensorKind k y))
-> Delta target (BuildTensorKind k y)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevFTK EvalState target
s (SNat k
-> SingletonTK (ADTensorKind y)
-> target (ADTensorKind y)
-> target (BuildTensorKind k (ADTensorKind y))
forall (z :: TK) (k :: Nat).
ConvertTensor target =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Nat).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate SNat k
snat (SingletonTK y -> SingletonTK (ADTensorKind y)
forall (y :: TK). SingletonTK y -> SingletonTK (ADTensorKind y)
adSTK SingletonTK y
stk) target (ADTensorKind y)
c) Delta target (BuildTensorKind k y)
d
DeltaReplicate SNat k
snat SingletonTK y
stk Delta target y
d | (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
Refl <- SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
forall (k :: Nat) (y :: TK).
SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
lemBuildOfAD SNat k
snat SingletonTK y
stk ->
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevFTK EvalState target
s (SNat k
-> SingletonTK (ADTensorKind y)
-> target (BuildTensorKind k (ADTensorKind y))
-> target (ADTensorKind y)
forall (z :: TK) (k :: Nat).
ConvertTensor target =>
SNat k -> SingletonTK z -> target (BuildTensorKind k z) -> target z
forall (target :: Target) (z :: TK) (k :: Nat).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target (BuildTensorKind k z) -> target z
tsum SNat k
snat (SingletonTK y -> SingletonTK (ADTensorKind y)
forall (y :: TK). SingletonTK y -> SingletonTK (ADTensorKind y)
adSTK SingletonTK y
stk) target (ADTensorKind y)
target (BuildTensorKind k (ADTensorKind y))
c) Delta target y
d
DeltaMapAccumR SNat k
k FullShapeTK by
bftk FullShapeTK ey
eftk target (BuildTensorKind k accy)
q target (BuildTensorKind k ey)
es HFun
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
_df HFun
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
rf Delta target accy
acc0' Delta target (BuildTensorKind k ey)
es'
| (:~:)
@TK
(BuildTensorKind k (ADTensorKind by))
(ADTensorKind (BuildTensorKind k by))
Refl <- SNat k
-> SingletonTK by
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind by))
(ADTensorKind (BuildTensorKind k by))
forall (k :: Nat) (y :: TK).
SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
lemBuildOfAD SNat k
k (FullShapeTK by -> SingletonTK by
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK by
bftk)
, (:~:)
@TK
(BuildTensorKind k (ADTensorKind ey))
(ADTensorKind (BuildTensorKind k ey))
Refl <- SNat k
-> SingletonTK ey
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind ey))
(ADTensorKind (BuildTensorKind k ey))
forall (k :: Nat) (y :: TK).
SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
lemBuildOfAD SNat k
k (FullShapeTK ey -> SingletonTK ey
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK ey
eftk) ->
let accftk :: FullShapeTK accy
accftk = Delta target accy -> FullShapeTK accy
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target accy
acc0'
accftkAD :: FullShapeTK (ADTensorKind accy)
accftkAD = FullShapeTK accy -> FullShapeTK (ADTensorKind accy)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK accy
accftk
bftkAD :: FullShapeTK (ADTensorKind by)
bftkAD = FullShapeTK by -> FullShapeTK (ADTensorKind by)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK by
bftk
eftkAD :: FullShapeTK (ADTensorKind ey)
eftkAD = FullShapeTK ey -> FullShapeTK (ADTensorKind ey)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK ey
eftk
(target (ADTensorKind accy)
c0, target (BuildTensorKind k (ADTensorKind by))
crest) = target
(TKProduct
(ADTensorKind accy) (BuildTensorKind k (ADTensorKind by)))
-> (target (ADTensorKind accy),
target (BuildTensorKind k (ADTensorKind by)))
forall (x :: TK) (z :: TK).
target (TKProduct x z) -> (target x, target z)
forall (target :: Target) (x :: TK) (z :: TK).
ShareTensor target =>
target (TKProduct x z) -> (target x, target z)
tunpair target (ADTensorKind y)
target
(TKProduct
(ADTensorKind accy) (BuildTensorKind k (ADTensorKind by)))
c
dacc_des :: target
(TKProduct
(ADTensorKind accy) (BuildTensorKind k (ADTensorKind ey)))
dacc_des =
Proxy @Target target
-> SNat k
-> FullShapeTK (ADTensorKind accy)
-> FullShapeTK (ADTensorKind ey)
-> FullShapeTK (TKProduct (ADTensorKind by) (TKProduct accy ey))
-> (forall (f :: Target).
ADReady f =>
f (ADTensorKind accy)
-> f (TKProduct (ADTensorKind by) (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey)))
-> target (ADTensorKind accy)
-> target
(BuildTensorKind
k (TKProduct (ADTensorKind by) (TKProduct accy ey)))
-> target
(TKProduct
(ADTensorKind accy) (BuildTensorKind k (ADTensorKind ey)))
forall (accy :: TK) (by :: TK) (ey :: TK) (k :: Nat)
(target :: Target).
BaseTensor target =>
Proxy @Target target
-> SNat k
-> FullShapeTK accy
-> FullShapeTK by
-> FullShapeTK ey
-> (forall (f :: Target).
ADReady f =>
f accy -> f ey -> f (TKProduct accy by))
-> target accy
-> target (BuildTensorKind k ey)
-> target (TKProduct accy (BuildTensorKind k by))
tmapAccumL (forall {k} (t :: k). Proxy @k t
forall (t :: Target). Proxy @Target t
Proxy @target)
SNat k
k FullShapeTK (ADTensorKind accy)
accftkAD FullShapeTK (ADTensorKind ey)
eftkAD (FullShapeTK (ADTensorKind by)
-> FullShapeTK (TKProduct accy ey)
-> FullShapeTK (TKProduct (ADTensorKind by) (TKProduct accy ey))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK (ADTensorKind by)
bftkAD
(FullShapeTK accy
-> FullShapeTK ey -> FullShapeTK (TKProduct accy ey)
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK accy
accftk FullShapeTK ey
eftk))
(\f (ADTensorKind accy)
dx f (TKProduct (ADTensorKind by) (TKProduct accy ey))
db_acc_e ->
f (TKProduct (ADTensorKind by) (TKProduct accy ey))
-> (f (TKProduct (ADTensorKind by) (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey)))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet f (TKProduct (ADTensorKind by) (TKProduct accy ey))
db_acc_e ((f (TKProduct (ADTensorKind by) (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey)))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey)))
-> (f (TKProduct (ADTensorKind by) (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey)))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct (ADTensorKind by) (TKProduct accy ey))
db_acc_e1 ->
HFun
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> forall (f :: Target).
ADReady f =>
f (TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (x :: TK) (z :: TK).
HFun x z -> forall (f :: Target). ADReady f => f x -> f z
unHFun HFun
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
HFun
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
rf (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> f (TKProduct accy ey)
-> f (TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (f (ADTensorKind accy)
-> f (ADTensorKind by)
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair f (ADTensorKind accy)
dx (f (TKProduct (ADTensorKind by) (TKProduct accy ey))
-> f (ADTensorKind by)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct (ADTensorKind by) (TKProduct accy ey))
db_acc_e1))
(f (TKProduct (ADTensorKind by) (TKProduct accy ey))
-> f (TKProduct accy ey)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct (ADTensorKind by) (TKProduct accy ey))
db_acc_e1)))
target (ADTensorKind accy)
c0
(target (BuildTensorKind k (ADTensorKind by))
-> target
(TKProduct (BuildTensorKind k accy) (BuildTensorKind k ey))
-> target
(TKProduct
(BuildTensorKind k (ADTensorKind by))
(TKProduct (BuildTensorKind k accy) (BuildTensorKind k ey)))
forall (x :: TK) (z :: TK).
target x -> target z -> target (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair target (BuildTensorKind k (ADTensorKind by))
crest (target (BuildTensorKind k accy)
-> target (BuildTensorKind k ey)
-> target
(TKProduct (BuildTensorKind k accy) (BuildTensorKind k ey))
forall (x :: TK) (z :: TK).
target x -> target z -> target (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair target (BuildTensorKind k accy)
q target (BuildTensorKind k ey)
es))
(target (ADTensorKind accy)
dacc, target (BuildTensorKind k (ADTensorKind ey))
des) = target
(TKProduct
(ADTensorKind accy) (BuildTensorKind k (ADTensorKind ey)))
-> (target (ADTensorKind accy),
target (BuildTensorKind k (ADTensorKind ey)))
forall (x :: TK) (z :: TK).
target (TKProduct x z) -> (target x, target z)
forall (target :: Target) (x :: TK) (z :: TK).
ShareTensor target =>
target (TKProduct x z) -> (target x, target z)
tunpair target
(TKProduct
(ADTensorKind accy) (BuildTensorKind k (ADTensorKind ey)))
dacc_des
s2 :: EvalState target
s2 = EvalState target
-> target (ADTensorKind accy)
-> Delta target accy
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevFTK EvalState target
s target (ADTensorKind accy)
dacc Delta target accy
acc0'
in EvalState target
-> target (ADTensorKind (BuildTensorKind k ey))
-> Delta target (BuildTensorKind k ey)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevFTK EvalState target
s2 target (ADTensorKind (BuildTensorKind k ey))
target (BuildTensorKind k (ADTensorKind ey))
des Delta target (BuildTensorKind k ey)
es'
DeltaMapAccumL SNat k
k FullShapeTK by
bftk FullShapeTK ey
eftk target (BuildTensorKind k accy)
q target (BuildTensorKind k ey)
es HFun
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
_df HFun
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
rf Delta target accy
acc0' Delta target (BuildTensorKind k ey)
es'
| (:~:)
@TK
(BuildTensorKind k (ADTensorKind by))
(ADTensorKind (BuildTensorKind k by))
Refl <- SNat k
-> SingletonTK by
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind by))
(ADTensorKind (BuildTensorKind k by))
forall (k :: Nat) (y :: TK).
SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
lemBuildOfAD SNat k
k (FullShapeTK by -> SingletonTK by
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK by
bftk)
, (:~:)
@TK
(BuildTensorKind k (ADTensorKind ey))
(ADTensorKind (BuildTensorKind k ey))
Refl <- SNat k
-> SingletonTK ey
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind ey))
(ADTensorKind (BuildTensorKind k ey))
forall (k :: Nat) (y :: TK).
SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
lemBuildOfAD SNat k
k (FullShapeTK ey -> SingletonTK ey
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK ey
eftk) ->
let accftk :: FullShapeTK accy
accftk = Delta target accy -> FullShapeTK accy
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target accy
acc0'
accftkAD :: FullShapeTK (ADTensorKind accy)
accftkAD = FullShapeTK accy -> FullShapeTK (ADTensorKind accy)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK accy
accftk
bftkAD :: FullShapeTK (ADTensorKind by)
bftkAD = FullShapeTK by -> FullShapeTK (ADTensorKind by)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK by
bftk
eftkAD :: FullShapeTK (ADTensorKind ey)
eftkAD = FullShapeTK ey -> FullShapeTK (ADTensorKind ey)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK ey
eftk
(target (ADTensorKind accy)
c0, target (BuildTensorKind k (ADTensorKind by))
crest) = target
(TKProduct
(ADTensorKind accy) (BuildTensorKind k (ADTensorKind by)))
-> (target (ADTensorKind accy),
target (BuildTensorKind k (ADTensorKind by)))
forall (x :: TK) (z :: TK).
target (TKProduct x z) -> (target x, target z)
forall (target :: Target) (x :: TK) (z :: TK).
ShareTensor target =>
target (TKProduct x z) -> (target x, target z)
tunpair target (ADTensorKind y)
target
(TKProduct
(ADTensorKind accy) (BuildTensorKind k (ADTensorKind by)))
c
dacc_des :: target
(TKProduct
(ADTensorKind accy) (BuildTensorKind k (ADTensorKind ey)))
dacc_des =
Proxy @Target target
-> SNat k
-> FullShapeTK (ADTensorKind accy)
-> FullShapeTK (ADTensorKind ey)
-> FullShapeTK (TKProduct (ADTensorKind by) (TKProduct accy ey))
-> (forall (f :: Target).
ADReady f =>
f (ADTensorKind accy)
-> f (TKProduct (ADTensorKind by) (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey)))
-> target (ADTensorKind accy)
-> target
(BuildTensorKind
k (TKProduct (ADTensorKind by) (TKProduct accy ey)))
-> target
(TKProduct
(ADTensorKind accy) (BuildTensorKind k (ADTensorKind ey)))
forall (accy :: TK) (by :: TK) (ey :: TK) (k :: Nat)
(target :: Target).
BaseTensor target =>
Proxy @Target target
-> SNat k
-> FullShapeTK accy
-> FullShapeTK by
-> FullShapeTK ey
-> (forall (f :: Target).
ADReady f =>
f accy -> f ey -> f (TKProduct accy by))
-> target accy
-> target (BuildTensorKind k ey)
-> target (TKProduct accy (BuildTensorKind k by))
tmapAccumR (forall {k} (t :: k). Proxy @k t
forall (t :: Target). Proxy @Target t
Proxy @target)
SNat k
k FullShapeTK (ADTensorKind accy)
accftkAD FullShapeTK (ADTensorKind ey)
eftkAD (FullShapeTK (ADTensorKind by)
-> FullShapeTK (TKProduct accy ey)
-> FullShapeTK (TKProduct (ADTensorKind by) (TKProduct accy ey))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK (ADTensorKind by)
bftkAD
(FullShapeTK accy
-> FullShapeTK ey -> FullShapeTK (TKProduct accy ey)
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK accy
accftk FullShapeTK ey
eftk))
(\f (ADTensorKind accy)
dx f (TKProduct (ADTensorKind by) (TKProduct accy ey))
db_acc_e ->
f (TKProduct (ADTensorKind by) (TKProduct accy ey))
-> (f (TKProduct (ADTensorKind by) (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey)))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet f (TKProduct (ADTensorKind by) (TKProduct accy ey))
db_acc_e ((f (TKProduct (ADTensorKind by) (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey)))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey)))
-> (f (TKProduct (ADTensorKind by) (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey)))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct (ADTensorKind by) (TKProduct accy ey))
db_acc_e1 ->
HFun
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> forall (f :: Target).
ADReady f =>
f (TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (x :: TK) (z :: TK).
HFun x z -> forall (f :: Target). ADReady f => f x -> f z
unHFun HFun
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
HFun
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
rf (f (TKProduct (ADTensorKind accy) (ADTensorKind by))
-> f (TKProduct accy ey)
-> f (TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind by))
(TKProduct accy ey))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (f (ADTensorKind accy)
-> f (ADTensorKind by)
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair f (ADTensorKind accy)
dx (f (TKProduct (ADTensorKind by) (TKProduct accy ey))
-> f (ADTensorKind by)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct (ADTensorKind by) (TKProduct accy ey))
db_acc_e1))
(f (TKProduct (ADTensorKind by) (TKProduct accy ey))
-> f (TKProduct accy ey)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct (ADTensorKind by) (TKProduct accy ey))
db_acc_e1)))
target (ADTensorKind accy)
c0
(target (BuildTensorKind k (ADTensorKind by))
-> target
(TKProduct (BuildTensorKind k accy) (BuildTensorKind k ey))
-> target
(TKProduct
(BuildTensorKind k (ADTensorKind by))
(TKProduct (BuildTensorKind k accy) (BuildTensorKind k ey)))
forall (x :: TK) (z :: TK).
target x -> target z -> target (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair target (BuildTensorKind k (ADTensorKind by))
crest (target (BuildTensorKind k accy)
-> target (BuildTensorKind k ey)
-> target
(TKProduct (BuildTensorKind k accy) (BuildTensorKind k ey))
forall (x :: TK) (z :: TK).
target x -> target z -> target (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair target (BuildTensorKind k accy)
q target (BuildTensorKind k ey)
es))
(target (ADTensorKind accy)
dacc, target (BuildTensorKind k (ADTensorKind ey))
des) = target
(TKProduct
(ADTensorKind accy) (BuildTensorKind k (ADTensorKind ey)))
-> (target (ADTensorKind accy),
target (BuildTensorKind k (ADTensorKind ey)))
forall (x :: TK) (z :: TK).
target (TKProduct x z) -> (target x, target z)
forall (target :: Target) (x :: TK) (z :: TK).
ShareTensor target =>
target (TKProduct x z) -> (target x, target z)
tunpair target
(TKProduct
(ADTensorKind accy) (BuildTensorKind k (ADTensorKind ey)))
dacc_des
s2 :: EvalState target
s2 = EvalState target
-> target (ADTensorKind accy)
-> Delta target accy
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevFTK EvalState target
s target (ADTensorKind accy)
dacc Delta target accy
acc0'
in EvalState target
-> target (ADTensorKind (BuildTensorKind k ey))
-> Delta target (BuildTensorKind k ey)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevFTK EvalState target
s2 target (ADTensorKind (BuildTensorKind k ey))
target (BuildTensorKind k (ADTensorKind ey))
des Delta target (BuildTensorKind k ey)
es'
Delta target y
_ -> let y :: FullShapeTK y
y = Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
d0
in case FullShapeTK y
-> FullShapeTK (ADTensorKind y)
-> Maybe ((:~:) @TK y (ADTensorKind y))
forall (y1 :: TK) (y2 :: TK).
FullShapeTK y1 -> FullShapeTK y2 -> Maybe ((:~:) @TK y1 y2)
matchingFTK FullShapeTK y
y (FullShapeTK y -> FullShapeTK (ADTensorKind y)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK y
y) of
Just (:~:) @TK y (ADTensorKind y)
Refl -> EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s target (ADTensorKind y)
c Delta target y
d0
Maybe ((:~:) @TK y (ADTensorKind y))
_ -> EvalState target
s
evalRevSame
:: forall y target.
(ADReadyNoLet target, ShareTensor target, y ~ ADTensorKind y)
=> EvalState target -> target (ADTensorKind y) -> Delta target y
-> EvalState target
evalRevSame :: forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame !EvalState target
s !target (ADTensorKind y)
c = \case
DeltaInput InputId target y
i ->
let cs :: TensorOrZero target y
cs = target y -> TensorOrZero target y
forall (target :: Target) (y :: TK).
target y -> TensorOrZero target y
TOTensor target y
target (ADTensorKind y)
c
in EvalState target
s {iMap = DMap.adjust (addTensorOrZero (ftkToSTK $ inputIdToFTK i) cs) i
$ iMap s}
DeltaZero{} -> EvalState target
s
DeltaScale (NestedTarget target y
k) Delta target y
d -> EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (target y
k target y -> target y -> target y
forall a. Num a => a -> a -> a
* target y
target (ADTensorKind y)
c) Delta target y
d
DeltaAdd Delta target y
d Delta target y
e ->
let cShared :: target y
cShared = target y -> target y
forall (y :: TK). target y -> target y
forall (target :: Target) (y :: TK).
ShareTensor target =>
target y -> target y
tshare target y
target (ADTensorKind y)
c
in EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame (EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s target y
target (ADTensorKind y)
cShared Delta target y
d) target y
target (ADTensorKind y)
cShared Delta target y
e
DeltaCastK @r1 Delta target (TKScalar r1)
d ->
EvalState target
-> target (ADTensorKind (TKScalar r1))
-> Delta target (TKScalar r1)
-> EvalState target
forall r (target :: Target).
(GoodScalar r, ADReadyNoLet target, ShareTensor target) =>
EvalState target
-> target (ADTensorKind (TKScalar r))
-> Delta target (TKScalar r)
-> EvalState target
evalRevScalarRuntimeSpecialized
EvalState target
s (FullShapeTK (TKScalar r1)
-> target (TKScalar r1) -> target (ADTensorKind (TKScalar r1))
forall (target :: Target) (y :: TK).
(BaseTensor target, ConvertTensor target, ShareTensor target) =>
FullShapeTK y -> target y -> target (ADTensorKind y)
toADTensorKindShared (forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar @r1) (target (TKScalar r1) -> target (ADTensorKind (TKScalar r1)))
-> target (TKScalar r1) -> target (ADTensorKind (TKScalar r1))
forall a b. (a -> b) -> a -> b
$ target (TKScalar r2) -> target (TKScalar r1)
forall r1 r2.
(RealFrac r1, GoodScalar r1, RealFrac r2, GoodScalar r2) =>
target (TKScalar r1) -> target (TKScalar r2)
forall (target :: Target) r1 r2.
(BaseTensor target, RealFrac r1, GoodScalar r1, RealFrac r2,
GoodScalar r2) =>
target (TKScalar r1) -> target (TKScalar r2)
tkcast target (ADTensorKind y)
target (TKScalar r2)
c) Delta target (TKScalar r1)
d
DeltaCastR Delta target (TKR n r1)
d -> case Delta target (TKR n r1) -> FullShapeTK (TKR n r1)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR n r1)
d of
FullShapeTK (TKR n r1)
y ->
EvalState target
-> target (ADTensorKind (TKR n r1))
-> Delta target (TKR n r1)
-> EvalState target
forall (n :: Nat) r (target :: Target).
(GoodScalar r, ADReadyNoLet target, ShareTensor target) =>
EvalState target
-> target (ADTensorKind (TKR n r))
-> Delta target (TKR n r)
-> EvalState target
evalRevRRuntimeSpecialized
EvalState target
s (FullShapeTK (TKR n r1)
-> target (TKR n r1) -> target (ADTensorKind (TKR n r1))
forall (target :: Target) (y :: TK).
(BaseTensor target, ConvertTensor target, ShareTensor target) =>
FullShapeTK y -> target y -> target (ADTensorKind y)
toADTensorKindShared FullShapeTK (TKR n r1)
y (target (TKR n r1) -> target (ADTensorKind (TKR n r1)))
-> target (TKR n r1) -> target (ADTensorKind (TKR n r1))
forall a b. (a -> b) -> a -> b
$ target (TKR2 n (TKScalar r2)) -> target (TKR n r1)
forall r1 r2 (n :: Nat).
(RealFrac r1, GoodScalar r1, RealFrac r2, GoodScalar r2) =>
target (TKR n r1) -> target (TKR n r2)
forall (target :: Target) r1 r2 (n :: Nat).
(BaseTensor target, RealFrac r1, GoodScalar r1, RealFrac r2,
GoodScalar r2) =>
target (TKR n r1) -> target (TKR n r2)
trcast target (ADTensorKind y)
target (TKR2 n (TKScalar r2))
c) Delta target (TKR n r1)
d
DeltaSum0R Delta target (TKR2 n r)
d -> case Delta target (TKR2 n r) -> FullShapeTK (TKR2 n r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 n r)
d of
FTKR IShR n
sh FullShapeTK x
x | SNat n
SNat <- IShR n -> SNat n
forall (n :: Nat) i. ShR n i -> SNat n
shrRank IShR n
sh ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKR2 n r))
-> Delta target (TKR2 n r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (IShR n -> target (TKR2 0 r) -> target (TKR2 n r)
forall (n :: Nat) (x :: TK).
(KnownNat n, KnownSTK x) =>
IShR n -> target (TKR2 0 x) -> target (TKR2 n x)
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownNat n, KnownSTK x) =>
IShR n -> target (TKR2 0 x) -> target (TKR2 n x)
trreplicate0N IShR n
sh target (ADTensorKind y)
target (TKR2 0 r)
c) Delta target (TKR2 n r)
d
DeltaDot0R target (TKR n r)
v Delta target (TKR n r)
d -> case Delta target (TKR n r) -> FullShapeTK (TKR n r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR n r)
d of
FTKR IShR n
sh FullShapeTK x
x | SNat n
SNat <- IShR n -> SNat n
forall (n :: Nat) i. ShR n i -> SNat n
shrRank IShR n
sh ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKR n r))
-> Delta target (TKR n r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (target (TKR n r)
v target (TKR n r) -> target (TKR n r) -> target (TKR n r)
forall a. Num a => a -> a -> a
* IShR n -> target (TKR2 0 (TKScalar r)) -> target (TKR n r)
forall (n :: Nat) (x :: TK).
(KnownNat n, KnownSTK x) =>
IShR n -> target (TKR2 0 x) -> target (TKR2 n x)
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownNat n, KnownSTK x) =>
IShR n -> target (TKR2 0 x) -> target (TKR2 n x)
trreplicate0N (target (TKR n r) -> IShR n
forall (n :: Nat) (x :: TK).
KnownSTK x =>
target (TKR2 n x) -> IShR n
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKR2 n x) -> IShR n
rshape target (TKR n r)
v) target (ADTensorKind y)
target (TKR2 0 (TKScalar r))
c) Delta target (TKR n r)
d
DeltaIndexR SNat n
SNat Delta target (TKR2 (m + n) r)
d IxROf target m
ix -> case Delta target (TKR2 (m + n) r) -> FullShapeTK (TKR2 (m + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (m + n) r)
d of
FTKR IShR n
sh FullShapeTK x
x | SNat m
SNat <- IxROf target m -> SNat m
forall (n :: Nat) i. IxR n i -> SNat n
ixrRank IxROf target m
ix ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKR2 (m + n) r))
-> Delta target (TKR2 (m + n) r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (IShR m
-> target (TKR2 n r) -> IxROf target m -> target (TKR2 (m + n) r)
forall (m :: Nat) (n :: Nat) (x :: TK).
(KnownNat m, KnownNat n, KnownSTK x,
(BoolOf (PrimalOf target) :: Type) ~ (BoolOf target :: Type),
EqH (PrimalOf target) (TKScalar Int64)) =>
IShR m
-> target (TKR2 n x) -> IxROf target m -> target (TKR2 (m + n) x)
forall (target :: Target) (m :: Nat) (n :: Nat) (x :: TK).
(BaseTensor target, KnownNat m, KnownNat n, KnownSTK x,
(BoolOf (PrimalOf target) :: Type) ~ (BoolOf target :: Type),
EqH (PrimalOf target) (TKScalar Int64)) =>
IShR m
-> target (TKR2 n x) -> IxROf target m -> target (TKR2 (m + n) x)
troneHot (ShR (m + n) Int -> IShR m
forall (m :: Nat) (n :: Nat) i.
(KnownNat n, KnownNat m) =>
ShR (m + n) i -> ShR m i
shrTake IShR n
ShR (m + n) Int
sh) target (ADTensorKind y)
target (TKR2 n r)
c IxROf target m
ix) Delta target (TKR2 (m + n) r)
d
DeltaScatterR SNat m
SNat SNat n
SNat SNat p
SNat IShR (p + n)
_sh Delta target (TKR2 (m + n) r)
d IxROf target m -> IxROf target p
f -> case Delta target (TKR2 (m + n) r) -> FullShapeTK (TKR2 (m + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (m + n) r)
d of
FTKR IShR n
sh FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKR2 (m + n) r))
-> Delta target (TKR2 (m + n) r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (IShR (m + n)
-> target (TKR2 (p + n) r)
-> (IxROf target m -> IxROf target p)
-> target (TKR2 (m + n) r)
forall (m :: Nat) (n :: Nat) (p :: Nat) (x :: TK).
(KnownNat m, KnownNat n, KnownNat p, KnownSTK x) =>
IShR (m + n)
-> target (TKR2 (p + n) x)
-> (IxROf target m -> IxROf target p)
-> target (TKR2 (m + n) x)
forall (target :: Target) (m :: Nat) (n :: Nat) (p :: Nat)
(x :: TK).
(BaseTensor target, KnownNat m, KnownNat n, KnownNat p,
KnownSTK x) =>
IShR (m + n)
-> target (TKR2 (p + n) x)
-> (IxROf target m -> IxROf target p)
-> target (TKR2 (m + n) x)
trgather IShR n
IShR (m + n)
sh target (ADTensorKind y)
target (TKR2 (p + n) r)
c IxROf target m -> IxROf target p
f) Delta target (TKR2 (m + n) r)
d
DeltaGatherR SNat m
SNat SNat n
SNat SNat p
SNat IShR (m + n)
_sh Delta target (TKR2 (p + n) r)
d IxROf target m -> IxROf target p
f -> case Delta target (TKR2 (p + n) r) -> FullShapeTK (TKR2 (p + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (p + n) r)
d of
FTKR IShR n
sh FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKR2 (p + n) r))
-> Delta target (TKR2 (p + n) r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (IShR (p + n)
-> target (TKR2 (m + n) r)
-> (IxROf target m -> IxROf target p)
-> target (TKR2 (p + n) r)
forall (m :: Nat) (n :: Nat) (p :: Nat) (x :: TK).
(KnownNat m, KnownNat n, KnownNat p, KnownSTK x) =>
IShR (p + n)
-> target (TKR2 (m + n) x)
-> (IxROf target m -> IxROf target p)
-> target (TKR2 (p + n) x)
forall (target :: Target) (m :: Nat) (n :: Nat) (p :: Nat)
(x :: TK).
(BaseTensor target, KnownNat m, KnownNat n, KnownNat p,
KnownSTK x) =>
IShR (p + n)
-> target (TKR2 (m + n) x)
-> (IxROf target m -> IxROf target p)
-> target (TKR2 (p + n) x)
trscatter IShR n
IShR (p + n)
sh target (ADTensorKind y)
target (TKR2 (m + n) r)
c IxROf target m -> IxROf target p
f) Delta target (TKR2 (p + n) r)
d
DeltaAppendR Delta target (TKR2 (1 + n) r)
d Delta target (TKR2 (1 + n) r)
e -> case (Delta target (TKR2 (1 + n) r) -> FullShapeTK (TKR2 (1 + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (1 + n) r)
d, Delta target (TKR2 (1 + n) r) -> FullShapeTK (TKR2 (1 + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (1 + n) r)
e) of
(FTKR (Int
m :$: ShR n Int
_) FullShapeTK x
x, FTKR (Int
n :$: ShR n Int
_) FullShapeTK x
_) ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
let cShared :: target (TKR2 n r)
cShared = target (TKR2 n r) -> target (TKR2 n r)
forall (y :: TK). target y -> target y
forall (target :: Target) (y :: TK).
ShareTensor target =>
target y -> target y
tshare target (ADTensorKind y)
target (TKR2 n r)
c
s2 :: EvalState target
s2 = EvalState target
-> target (ADTensorKind (TKR2 (1 + n) r))
-> Delta target (TKR2 (1 + n) r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (Int -> Int -> target (TKR2 (1 + n) r) -> target (TKR2 (1 + n) r)
forall (n :: Nat) (x :: TK).
KnownSTK x =>
Int -> Int -> target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownSTK x) =>
Int -> Int -> target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
trslice Int
0 Int
m target (TKR2 n r)
target (TKR2 (1 + n) r)
cShared) Delta target (TKR2 (1 + n) r)
d
in EvalState target
-> target (ADTensorKind (TKR2 (1 + n) r))
-> Delta target (TKR2 (1 + n) r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s2 (Int -> Int -> target (TKR2 (1 + n) r) -> target (TKR2 (1 + n) r)
forall (n :: Nat) (x :: TK).
KnownSTK x =>
Int -> Int -> target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownSTK x) =>
Int -> Int -> target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
trslice Int
m Int
n target (TKR2 n r)
target (TKR2 (1 + n) r)
cShared) Delta target (TKR2 (1 + n) r)
e
(FullShapeTK (TKR2 (1 + n) r), FullShapeTK (TKR2 (1 + n) r))
_ -> String -> EvalState target
forall a. (?callStack::CallStack) => String -> a
error String
"evalRevSame: impossible pattern needlessly required"
DeltaSliceR Int
i Int
n Delta target (TKR2 (1 + n) r)
d -> case Delta target (TKR2 (1 + n) r) -> FullShapeTK (TKR2 (1 + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (1 + n) r)
d of
FTKR (Int
l :$: ShR n Int
rest) FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKR2 (1 + n) x))
-> Delta target (TKR2 (1 + n) x)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (target (TKR2 (1 + n) x)
-> target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
forall (n :: Nat) (x :: TK).
KnownSTK x =>
target (TKR2 (1 + n) x)
-> target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKR2 (1 + n) x)
-> target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
trappend
(FullShapeTK (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget (IShR (1 + n) -> FullShapeTK x -> FullShapeTK (TKR2 (1 + n) x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR (Int
i Int -> ShR n Int -> IShR (1 + n)
forall {n1 :: Nat} {i} (n :: Nat).
((n + 1 :: Nat) ~ (n1 :: Nat)) =>
i -> ShR n i -> ShR n1 i
:$: ShR n Int
rest) FullShapeTK x
x))
(target (TKR2 (1 + n) x)
-> target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
forall (n :: Nat) (x :: TK).
KnownSTK x =>
target (TKR2 (1 + n) x)
-> target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKR2 (1 + n) x)
-> target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
trappend target (ADTensorKind y)
target (TKR2 (1 + n) x)
c
(FullShapeTK (TKR2 n x) -> target (TKR2 n x)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget (IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR (Int
l Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
i Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
n Int -> ShR n Int -> IShR n
forall {n1 :: Nat} {i} (n :: Nat).
((n + 1 :: Nat) ~ (n1 :: Nat)) =>
i -> ShR n i -> ShR n1 i
:$: ShR n Int
rest) FullShapeTK x
x)))) Delta target (TKR2 (1 + n) r)
Delta target (TKR2 (1 + n) x)
d
FTKR ShR n Int
ZSR FullShapeTK x
_ -> String -> EvalState target
forall a. (?callStack::CallStack) => String -> a
error String
"evalRevSame: impossible pattern needlessly required"
DeltaReverseR Delta target (TKR2 (1 + n) r)
d -> case Delta target (TKR2 (1 + n) r) -> FullShapeTK (TKR2 (1 + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (1 + n) r)
d of
FTKR IShR n
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKR2 (1 + n) r))
-> Delta target (TKR2 (1 + n) r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (target (TKR2 (1 + n) r) -> target (TKR2 (1 + n) r)
forall (n :: Nat) (x :: TK).
KnownSTK x =>
target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
trreverse target (ADTensorKind y)
target (TKR2 (1 + n) r)
c) Delta target (TKR2 (1 + n) r)
d
DeltaTransposeR PermR
perm Delta target (TKR2 n r)
d -> case Delta target (TKR2 n r) -> FullShapeTK (TKR2 n r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 n r)
d of
FTKR IShR n
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
let permR :: PermR
permR = PermR -> PermR
permRInverse PermR
perm
in EvalState target
-> target (ADTensorKind (TKR2 n r))
-> Delta target (TKR2 n r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (PermR -> target (TKR2 n r) -> target (TKR2 n r)
forall (n :: Nat) (x :: TK).
KnownSTK x =>
PermR -> target (TKR2 n x) -> target (TKR2 n x)
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownSTK x) =>
PermR -> target (TKR2 n x) -> target (TKR2 n x)
trtranspose PermR
permR target (ADTensorKind y)
target (TKR2 n r)
c) Delta target (TKR2 n r)
d
DeltaReshapeR IShR m
_sh2 Delta target (TKR2 n r)
d -> case Delta target (TKR2 n r) -> FullShapeTK (TKR2 n r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 n r)
d of
FTKR IShR n
sh FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKR2 n r))
-> Delta target (TKR2 n r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (IShR n -> target (TKR2 m r) -> target (TKR2 n r)
forall (n :: Nat) (m :: Nat) (x :: TK).
KnownSTK x =>
IShR m -> target (TKR2 n x) -> target (TKR2 m x)
forall (target :: Target) (n :: Nat) (m :: Nat) (x :: TK).
(BaseTensor target, KnownSTK x) =>
IShR m -> target (TKR2 n x) -> target (TKR2 m x)
trreshape IShR n
sh target (ADTensorKind y)
target (TKR2 m r)
c) Delta target (TKR2 n r)
d
DeltaCastS Delta target (TKS sh r1)
d -> case Delta target (TKS sh r1) -> FullShapeTK (TKS sh r1)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS sh r1)
d of
FullShapeTK (TKS sh r1)
y ->
EvalState target
-> target (ADTensorKind (TKS sh r1))
-> Delta target (TKS sh r1)
-> EvalState target
forall (sh :: [Nat]) r (target :: Target).
(GoodScalar r, ADReadyNoLet target, ShareTensor target) =>
EvalState target
-> target (ADTensorKind (TKS sh r))
-> Delta target (TKS sh r)
-> EvalState target
evalSRuntimeSpecialized
EvalState target
s (FullShapeTK (TKS sh r1)
-> target (TKS sh r1) -> target (ADTensorKind (TKS sh r1))
forall (target :: Target) (y :: TK).
(BaseTensor target, ConvertTensor target, ShareTensor target) =>
FullShapeTK y -> target y -> target (ADTensorKind y)
toADTensorKindShared FullShapeTK (TKS sh r1)
y (target (TKS sh r1) -> target (ADTensorKind (TKS sh r1)))
-> target (TKS sh r1) -> target (ADTensorKind (TKS sh r1))
forall a b. (a -> b) -> a -> b
$ target (TKS2 sh (TKScalar r2)) -> target (TKS sh r1)
forall r1 r2 (sh :: [Nat]).
(RealFrac r1, GoodScalar r1, RealFrac r2, GoodScalar r2) =>
target (TKS sh r1) -> target (TKS sh r2)
forall (target :: Target) r1 r2 (sh :: [Nat]).
(BaseTensor target, RealFrac r1, GoodScalar r1, RealFrac r2,
GoodScalar r2) =>
target (TKS sh r1) -> target (TKS sh r2)
tscast target (ADTensorKind y)
target (TKS2 sh (TKScalar r2))
c) Delta target (TKS sh r1)
d
DeltaSum0S Delta target (TKS2 sh r)
d -> case Delta target (TKS2 sh r) -> FullShapeTK (TKS2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 sh r)
d of
FTKS ShS sh
sh FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKS2 sh r))
-> Delta target (TKS2 sh r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (ShS sh -> target (TKS2 ('[] @Nat) r) -> target (TKS2 sh r)
forall (sh :: [Nat]) (x :: TK).
KnownSTK x =>
ShS sh -> target (TKS2 ('[] @Nat) x) -> target (TKS2 sh x)
forall (target :: Target) (sh :: [Nat]) (x :: TK).
(BaseTensor target, KnownSTK x) =>
ShS sh -> target (TKS2 ('[] @Nat) x) -> target (TKS2 sh x)
tsreplicate0N ShS sh
sh target (ADTensorKind y)
target (TKS2 ('[] @Nat) r)
c) Delta target (TKS2 sh r)
d
DeltaDot0S target (TKS sh r)
v Delta target (TKS sh r)
d -> case Delta target (TKS sh r) -> FullShapeTK (TKS sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS sh r)
d of
FTKS ShS sh
sh FullShapeTK x
FTKScalar ->
EvalState target
-> target (ADTensorKind (TKS sh r))
-> Delta target (TKS sh r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (target (TKS sh r)
v target (TKS sh r) -> target (TKS sh r) -> target (TKS sh r)
forall a. Num a => a -> a -> a
* ShS sh
-> target (TKS2 ('[] @Nat) (TKScalar r)) -> target (TKS sh r)
forall (sh :: [Nat]) (x :: TK).
KnownSTK x =>
ShS sh -> target (TKS2 ('[] @Nat) x) -> target (TKS2 sh x)
forall (target :: Target) (sh :: [Nat]) (x :: TK).
(BaseTensor target, KnownSTK x) =>
ShS sh -> target (TKS2 ('[] @Nat) x) -> target (TKS2 sh x)
tsreplicate0N ShS sh
ShS sh
sh target (ADTensorKind y)
target (TKS2 ('[] @Nat) (TKScalar r))
c) Delta target (TKS sh r)
d
DeltaIndexS ShS shn
shn Delta target (TKS2 ((++) @Nat shm shn) r)
d IxSOf target shm
ix -> case Delta target (TKS2 ((++) @Nat shm shn) r)
-> FullShapeTK (TKS2 ((++) @Nat shm shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((++) @Nat shm shn) r)
d of
FTKS ShS sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
ShS shn -> (KnownShS shn => EvalState target) -> EvalState target
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shn
shn ((KnownShS shn => EvalState target) -> EvalState target)
-> (KnownShS shn => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
ShS shm -> (KnownShS shm => EvalState target) -> EvalState target
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS (IxSOf target shm -> ShS shm
forall (sh :: [Nat]) i. IxS sh i -> ShS sh
shsFromIxS IxSOf target shm
ix) ((KnownShS shm => EvalState target) -> EvalState target)
-> (KnownShS shm => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKS2 ((++) @Nat shm shn) r))
-> Delta target (TKS2 ((++) @Nat shm shn) r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (target (TKS2 shn r)
-> IxSOf target shm -> target (TKS2 ((++) @Nat shm shn) r)
forall (sh1 :: [Nat]) (sh2 :: [Nat]) (x :: TK).
(KnownShS sh1, KnownShS sh2, KnownSTK x,
(BoolOf (PrimalOf target) :: Type) ~ (BoolOf target :: Type),
EqH (PrimalOf target) (TKScalar Int64)) =>
target (TKS2 sh2 x)
-> IxSOf target sh1 -> target (TKS2 ((++) @Nat sh1 sh2) x)
forall (target :: Target) (sh1 :: [Nat]) (sh2 :: [Nat]) (x :: TK).
(BaseTensor target, KnownShS sh1, KnownShS sh2, KnownSTK x,
(BoolOf (PrimalOf target) :: Type) ~ (BoolOf target :: Type),
EqH (PrimalOf target) (TKScalar Int64)) =>
target (TKS2 sh2 x)
-> IxSOf target sh1 -> target (TKS2 ((++) @Nat sh1 sh2) x)
tsoneHot target (ADTensorKind y)
target (TKS2 shn r)
c IxSOf target shm
ix) Delta target (TKS2 ((++) @Nat shm shn) r)
d
DeltaScatterS @shm @shn ShS shm
shm ShS shn
shn ShS shp
shp Delta target (TKS2 ((++) @Nat shm shn) r)
d IxSOf target shm -> IxSOf target shp
f -> case Delta target (TKS2 ((++) @Nat shm shn) r)
-> FullShapeTK (TKS2 ((++) @Nat shm shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((++) @Nat shm shn) r)
d of
FTKS ShS sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
ShS shm -> (KnownShS shm => EvalState target) -> EvalState target
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shm
shm ((KnownShS shm => EvalState target) -> EvalState target)
-> (KnownShS shm => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
ShS shn -> (KnownShS shn => EvalState target) -> EvalState target
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shn
shn ((KnownShS shn => EvalState target) -> EvalState target)
-> (KnownShS shn => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
ShS shp -> (KnownShS shp => EvalState target) -> EvalState target
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shp
shp ((KnownShS shp => EvalState target) -> EvalState target)
-> (KnownShS shp => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKS2 sh r))
-> Delta target (TKS2 sh r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (forall (target :: Target) (shm :: [Nat]) (shn :: [Nat])
(shp :: [Nat]) (x :: TK).
(BaseTensor target, KnownShS shm, KnownShS shn, KnownShS shp,
KnownSTK x) =>
target (TKS2 ((++) @Nat shp shn) x)
-> (IxSOf target shm -> IxSOf target shp)
-> target (TKS2 ((++) @Nat shm shn) x)
tsgather @_ @shm @shn target (ADTensorKind y)
target (TKS2 ((++) @Nat shp shn) r)
c IxSOf target shm -> IxSOf target shp
f) Delta target (TKS2 sh r)
Delta target (TKS2 ((++) @Nat shm shn) r)
d
DeltaGatherS @shm @shn ShS shm
shm ShS shn
shn ShS shp
shp Delta target (TKS2 ((++) @Nat shp shn) r)
d IxSOf target shm -> IxSOf target shp
f -> case Delta target (TKS2 ((++) @Nat shp shn) r)
-> FullShapeTK (TKS2 ((++) @Nat shp shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((++) @Nat shp shn) r)
d of
FTKS ShS sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
ShS shm -> (KnownShS shm => EvalState target) -> EvalState target
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shm
shm ((KnownShS shm => EvalState target) -> EvalState target)
-> (KnownShS shm => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
ShS shn -> (KnownShS shn => EvalState target) -> EvalState target
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shn
shn ((KnownShS shn => EvalState target) -> EvalState target)
-> (KnownShS shn => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
ShS shp -> (KnownShS shp => EvalState target) -> EvalState target
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shp
shp ((KnownShS shp => EvalState target) -> EvalState target)
-> (KnownShS shp => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKS2 sh r))
-> Delta target (TKS2 sh r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (forall (target :: Target) (shm :: [Nat]) (shn :: [Nat])
(shp :: [Nat]) (x :: TK).
(BaseTensor target, KnownShS shm, KnownShS shn, KnownShS shp,
KnownSTK x) =>
target (TKS2 ((++) @Nat shm shn) x)
-> (IxSOf target shm -> IxSOf target shp)
-> target (TKS2 ((++) @Nat shp shn) x)
tsscatter @_ @shm @shn target (ADTensorKind y)
target (TKS2 ((++) @Nat shm shn) r)
c IxSOf target shm -> IxSOf target shp
f) Delta target (TKS2 sh r)
Delta target (TKS2 ((++) @Nat shp shn) r)
d
DeltaAppendS Delta target (TKS2 ((':) @Nat m sh) r)
d Delta target (TKS2 ((':) @Nat n sh) r)
e -> case (Delta target (TKS2 ((':) @Nat m sh) r)
-> FullShapeTK (TKS2 ((':) @Nat m sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((':) @Nat m sh) r)
d, Delta target (TKS2 ((':) @Nat n sh) r)
-> FullShapeTK (TKS2 ((':) @Nat n sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((':) @Nat n sh) r)
e) of
(FTKS (SNat n
msnat :$$ ShS sh
_) FullShapeTK x
x, FTKS (SNat n
_ :$$ ShS sh
_) FullShapeTK x
_) ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
let cShared :: target (TKS2 ((':) @Nat (m + n) sh) r)
cShared = target (TKS2 ((':) @Nat (m + n) sh) r)
-> target (TKS2 ((':) @Nat (m + n) sh) r)
forall (y :: TK). target y -> target y
forall (target :: Target) (y :: TK).
ShareTensor target =>
target y -> target y
tshare target (ADTensorKind y)
target (TKS2 ((':) @Nat (m + n) sh) r)
c
s2 :: EvalState target
s2 = EvalState target
-> target (ADTensorKind (TKS2 ((':) @Nat m sh) r))
-> Delta target (TKS2 ((':) @Nat m sh) r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (SNat 0
-> SNat m
-> SNat n
-> target (TKS2 ((':) @Nat ((0 + m) + n) sh) r)
-> target (TKS2 ((':) @Nat m sh) r)
forall (i :: Nat) (n :: Nat) (k :: Nat) (sh :: [Nat]) (x :: TK).
KnownSTK x =>
SNat i
-> SNat n
-> SNat k
-> target (TKS2 ((':) @Nat ((i + n) + k) sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
forall (target :: Target) (i :: Nat) (n :: Nat) (k :: Nat)
(sh :: [Nat]) (x :: TK).
(BaseTensor target, KnownSTK x) =>
SNat i
-> SNat n
-> SNat k
-> target (TKS2 ((':) @Nat ((i + n) + k) sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
tsslice (forall (n :: Nat). KnownNat n => SNat n
SNat @0) SNat m
forall (n :: Nat). KnownNat n => SNat n
SNat SNat n
forall (n :: Nat). KnownNat n => SNat n
SNat target (TKS2 ((':) @Nat (m + n) sh) r)
target (TKS2 ((':) @Nat ((0 + m) + n) sh) r)
cShared) Delta target (TKS2 ((':) @Nat m sh) r)
d
in EvalState target
-> target (ADTensorKind (TKS2 ((':) @Nat n sh) r))
-> Delta target (TKS2 ((':) @Nat n sh) r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s2 (SNat n
-> SNat n
-> SNat 0
-> target (TKS2 ((':) @Nat ((n + n) + 0) sh) r)
-> target (TKS2 ((':) @Nat n sh) r)
forall (i :: Nat) (n :: Nat) (k :: Nat) (sh :: [Nat]) (x :: TK).
KnownSTK x =>
SNat i
-> SNat n
-> SNat k
-> target (TKS2 ((':) @Nat ((i + n) + k) sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
forall (target :: Target) (i :: Nat) (n :: Nat) (k :: Nat)
(sh :: [Nat]) (x :: TK).
(BaseTensor target, KnownSTK x) =>
SNat i
-> SNat n
-> SNat k
-> target (TKS2 ((':) @Nat ((i + n) + k) sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
tsslice SNat n
msnat SNat n
forall (n :: Nat). KnownNat n => SNat n
SNat SNat 0
forall (n :: Nat). KnownNat n => SNat n
SNat target (TKS2 ((':) @Nat (m + n) sh) r)
target (TKS2 ((':) @Nat ((n + n) + 0) sh) r)
cShared) Delta target (TKS2 ((':) @Nat n sh) r)
e
DeltaSliceS i :: SNat i
i@SNat i
SNat SNat n
_ k :: SNat k
k@SNat k
SNat Delta target (TKS2 ((':) @Nat ((i + n) + k) sh) r)
d -> case Delta target (TKS2 ((':) @Nat ((i + n) + k) sh) r)
-> FullShapeTK (TKS2 ((':) @Nat ((i + n) + k) sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((':) @Nat ((i + n) + k) sh) r)
d of
FTKS (SNat n
_ :$$ ShS sh
sh) FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKS2 ((':) @Nat (i + (n + k)) sh) r))
-> Delta target (TKS2 ((':) @Nat (i + (n + k)) sh) r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (target (TKS2 ((':) @Nat i sh) x)
-> target (TKS2 ((':) @Nat (n + k) sh) x)
-> target (TKS2 ((':) @Nat (i + (n + k)) sh) x)
forall (m :: Nat) (n :: Nat) (sh :: [Nat]) (x :: TK).
KnownSTK x =>
target (TKS2 ((':) @Nat m sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat (m + n) sh) x)
forall (target :: Target) (m :: Nat) (n :: Nat) (sh :: [Nat])
(x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKS2 ((':) @Nat m sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat (m + n) sh) x)
tsappend
(FullShapeTK (TKS2 ((':) @Nat i sh) x)
-> target (TKS2 ((':) @Nat i sh) x)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget (ShS ((':) @Nat i sh)
-> FullShapeTK x -> FullShapeTK (TKS2 ((':) @Nat i sh) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS (SNat i
i SNat i -> ShS sh -> ShS ((':) @Nat i sh)
forall {sh1 :: [Nat]} (n :: Nat) (sh :: [Nat]).
(KnownNat n, ((':) @Nat n sh :: [Nat]) ~ (sh1 :: [Nat])) =>
SNat n -> ShS sh -> ShS sh1
:$$ ShS sh
sh) FullShapeTK x
x))
(target (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat k sh) x)
-> target (TKS2 ((':) @Nat (n + k) sh) x)
forall (m :: Nat) (n :: Nat) (sh :: [Nat]) (x :: TK).
KnownSTK x =>
target (TKS2 ((':) @Nat m sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat (m + n) sh) x)
forall (target :: Target) (m :: Nat) (n :: Nat) (sh :: [Nat])
(x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKS2 ((':) @Nat m sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat (m + n) sh) x)
tsappend
target (ADTensorKind y)
target (TKS2 ((':) @Nat n sh) x)
c (FullShapeTK (TKS2 ((':) @Nat k sh) x)
-> target (TKS2 ((':) @Nat k sh) x)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget (ShS ((':) @Nat k sh)
-> FullShapeTK x -> FullShapeTK (TKS2 ((':) @Nat k sh) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS (SNat k
k SNat k -> ShS sh -> ShS ((':) @Nat k sh)
forall {sh1 :: [Nat]} (n :: Nat) (sh :: [Nat]).
(KnownNat n, ((':) @Nat n sh :: [Nat]) ~ (sh1 :: [Nat])) =>
SNat n -> ShS sh -> ShS sh1
:$$ ShS sh
sh) FullShapeTK x
x)))) Delta target (TKS2 ((':) @Nat (i + (n + k)) sh) r)
Delta target (TKS2 ((':) @Nat ((i + n) + k) sh) r)
d
DeltaReverseS Delta target (TKS2 ((':) @Nat n sh) r)
d -> case Delta target (TKS2 ((':) @Nat n sh) r)
-> FullShapeTK (TKS2 ((':) @Nat n sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((':) @Nat n sh) r)
d of
FTKS ShS sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKS2 ((':) @Nat n sh) r))
-> Delta target (TKS2 ((':) @Nat n sh) r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (target (TKS2 ((':) @Nat n sh) r)
-> target (TKS2 ((':) @Nat n sh) r)
forall (n :: Nat) (sh :: [Nat]) (x :: TK).
KnownSTK x =>
target (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
forall (target :: Target) (n :: Nat) (sh :: [Nat]) (x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
tsreverse target (ADTensorKind y)
target (TKS2 ((':) @Nat n sh) r)
c) Delta target (TKS2 ((':) @Nat n sh) r)
d
DeltaTransposeS @perm @sh2 Perm perm
perm Delta target (TKS2 sh r)
d -> case Delta target (TKS2 sh r) -> FullShapeTK (TKS2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 sh r)
d of
FTKS ShS sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
Perm perm
-> (forall (is' :: [Nat]).
(Assert
(AllElem' @Nat is' (Count 0 (Rank @Nat is'))) (TypeError ...),
Assert
(AllElem' @Nat (Count 0 (Rank @Nat is')) is') (TypeError ...)) =>
Perm is'
-> (forall (sh :: [Maybe Nat]).
((Rank @(Maybe Nat) sh :: Nat) ~ (Rank @Nat perm :: Nat)) =>
StaticShX sh
-> (:~:)
@[Maybe Nat]
(Permute @(Maybe Nat) is' (Permute @(Maybe Nat) perm sh))
sh)
-> EvalState target)
-> EvalState target
forall (is :: [Nat]) r.
Perm is
-> (forall (is' :: [Nat]).
IsPermutation is' =>
Perm is'
-> (forall (sh :: [Maybe Nat]).
((Rank @(Maybe Nat) sh :: Nat) ~ (Rank @Nat is :: Nat)) =>
StaticShX sh
-> (:~:)
@[Maybe Nat]
(Permute @(Maybe Nat) is' (Permute @(Maybe Nat) is sh))
sh)
-> r)
-> r
permInverse Perm perm
perm ((forall (is' :: [Nat]).
(Assert
(AllElem' @Nat is' (Count 0 (Rank @Nat is'))) (TypeError ...),
Assert
(AllElem' @Nat (Count 0 (Rank @Nat is')) is') (TypeError ...)) =>
Perm is'
-> (forall (sh :: [Maybe Nat]).
((Rank @(Maybe Nat) sh :: Nat) ~ (Rank @Nat perm :: Nat)) =>
StaticShX sh
-> (:~:)
@[Maybe Nat]
(Permute @(Maybe Nat) is' (Permute @(Maybe Nat) perm sh))
sh)
-> EvalState target)
-> EvalState target)
-> (forall (is' :: [Nat]).
(Assert
(AllElem' @Nat is' (Count 0 (Rank @Nat is'))) (TypeError ...),
Assert
(AllElem' @Nat (Count 0 (Rank @Nat is')) is') (TypeError ...)) =>
Perm is'
-> (forall (sh :: [Maybe Nat]).
((Rank @(Maybe Nat) sh :: Nat) ~ (Rank @Nat perm :: Nat)) =>
StaticShX sh
-> (:~:)
@[Maybe Nat]
(Permute @(Maybe Nat) is' (Permute @(Maybe Nat) perm sh))
sh)
-> EvalState target)
-> EvalState target
forall a b. (a -> b) -> a -> b
$ \(Perm is'
permRev :: Permutation.Perm permR) forall (sh :: [Maybe Nat]).
((Rank @(Maybe Nat) sh :: Nat) ~ (Rank @Nat perm :: Nat)) =>
StaticShX sh
-> (:~:)
@[Maybe Nat]
(Permute @(Maybe Nat) is' (Permute @(Maybe Nat) perm sh))
sh
_ ->
(:~:)
@[Nat]
((++)
@Nat
(Permute
@Nat is' (TakeLen @Nat @Nat is' (PermutePrefix @Nat perm sh)))
(DropLen @Nat @Nat is' (PermutePrefix @Nat perm sh)))
sh
-> ((((++)
@Nat
(Permute
@Nat is' (TakeLen @Nat @Nat is' (PermutePrefix @Nat perm sh)))
(DropLen @Nat @Nat is' (PermutePrefix @Nat perm sh)) :: [Nat])
~ (sh :: [Nat])) =>
EvalState target)
-> EvalState target
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:)
@[Nat]
((++)
@Nat
(Permute
@Nat is' (TakeLen @Nat @Nat is' (PermutePrefix @Nat perm sh)))
(DropLen @Nat @Nat is' (PermutePrefix @Nat perm sh)))
sh
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl
:: Permutation.PermutePrefix
permR (Permutation.PermutePrefix perm sh2) :~: sh2)
(((((++)
@Nat
(Permute
@Nat is' (TakeLen @Nat @Nat is' (PermutePrefix @Nat perm sh)))
(DropLen @Nat @Nat is' (PermutePrefix @Nat perm sh)) :: [Nat])
~ (sh :: [Nat])) =>
EvalState target)
-> EvalState target)
-> ((((++)
@Nat
(Permute
@Nat is' (TakeLen @Nat @Nat is' (PermutePrefix @Nat perm sh)))
(DropLen @Nat @Nat is' (PermutePrefix @Nat perm sh)) :: [Nat])
~ (sh :: [Nat])) =>
EvalState target)
-> EvalState target
forall a b. (a -> b) -> a -> b
$ (:~:) @Nat (Rank @Nat (PermutePrefix @Nat perm sh)) (Rank @Nat sh)
-> (((Rank @Nat (PermutePrefix @Nat perm sh) :: Nat)
~ (Rank @Nat sh :: Nat)) =>
EvalState target)
-> EvalState target
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:) @Nat (Rank @Nat (PermutePrefix @Nat perm sh)) (Rank @Nat sh)
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl
:: Rank (Permutation.PermutePrefix perm sh2) :~: Rank sh2)
((((Rank @Nat (PermutePrefix @Nat perm sh) :: Nat)
~ (Rank @Nat sh :: Nat)) =>
EvalState target)
-> EvalState target)
-> (((Rank @Nat (PermutePrefix @Nat perm sh) :: Nat)
~ (Rank @Nat sh :: Nat)) =>
EvalState target)
-> EvalState target
forall a b. (a -> b) -> a -> b
$ (:~:) @Nat (Rank @Nat is') (Rank @Nat perm)
-> (((Rank @Nat is' :: Nat) ~ (Rank @Nat perm :: Nat)) =>
EvalState target)
-> EvalState target
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:) @Nat (Rank @Nat is') (Rank @Nat perm)
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl
:: Rank permR :~: Rank perm)
((((Rank @Nat is' :: Nat) ~ (Rank @Nat perm :: Nat)) =>
EvalState target)
-> EvalState target)
-> (((Rank @Nat is' :: Nat) ~ (Rank @Nat perm :: Nat)) =>
EvalState target)
-> EvalState target
forall a b. (a -> b) -> a -> b
$ EvalState target
-> target (ADTensorKind (TKS2 sh r))
-> Delta target (TKS2 sh r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (Perm is'
-> target (TKS2 (PermutePrefix @Nat perm sh) r)
-> target
(TKS2
((++)
@Nat
(Permute
@Nat is' (TakeLen @Nat @Nat is' (PermutePrefix @Nat perm sh)))
(DropLen @Nat @Nat is' (PermutePrefix @Nat perm sh)))
r)
forall (perm :: [Nat]) (sh :: [Nat]) (x :: TK).
(IsPermutation perm, (<=) @Nat (Rank @Nat perm) (Rank @Nat sh),
KnownSTK x) =>
Perm perm
-> target (TKS2 sh x)
-> target (TKS2 (PermutePrefix @Nat perm sh) x)
forall (target :: Target) (perm :: [Nat]) (sh :: [Nat]) (x :: TK).
(BaseTensor target, IsPermutation perm,
(<=) @Nat (Rank @Nat perm) (Rank @Nat sh), KnownSTK x) =>
Perm perm
-> target (TKS2 sh x)
-> target (TKS2 (PermutePrefix @Nat perm sh) x)
tstranspose Perm is'
permRev target (ADTensorKind y)
target (TKS2 (PermutePrefix @Nat perm sh) r)
c) Delta target (TKS2 sh r)
d
DeltaReshapeS ShS sh2
_sh2 Delta target (TKS2 sh r)
d -> case Delta target (TKS2 sh r) -> FullShapeTK (TKS2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 sh r)
d of
FTKS ShS sh
sh FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKS2 sh r))
-> Delta target (TKS2 sh r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (ShS sh -> target (TKS2 sh2 r) -> target (TKS2 sh r)
forall (sh :: [Nat]) (sh2 :: [Nat]) (x :: TK).
((Product sh :: Nat) ~ (Product sh2 :: Nat), KnownSTK x) =>
ShS sh2 -> target (TKS2 sh x) -> target (TKS2 sh2 x)
forall (target :: Target) (sh :: [Nat]) (sh2 :: [Nat]) (x :: TK).
(BaseTensor target, (Product sh :: Nat) ~ (Product sh2 :: Nat),
KnownSTK x) =>
ShS sh2 -> target (TKS2 sh x) -> target (TKS2 sh2 x)
tsreshape ShS sh
sh target (ADTensorKind y)
target (TKS2 sh2 r)
c) Delta target (TKS2 sh r)
d
DeltaCastX Delta target (TKX sh r1)
d -> case Delta target (TKX sh r1) -> FullShapeTK (TKX sh r1)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX sh r1)
d of
FullShapeTK (TKX sh r1)
y ->
EvalState target
-> target (ADTensorKind (TKX sh r1))
-> Delta target (TKX sh r1)
-> EvalState target
forall (sh :: [Maybe Nat]) r (target :: Target).
(GoodScalar r, ADReadyNoLet target, ShareTensor target) =>
EvalState target
-> target (ADTensorKind (TKX sh r))
-> Delta target (TKX sh r)
-> EvalState target
evalXRuntimeSpecialized
EvalState target
s (FullShapeTK (TKX sh r1)
-> target (TKX sh r1) -> target (ADTensorKind (TKX sh r1))
forall (target :: Target) (y :: TK).
(BaseTensor target, ConvertTensor target, ShareTensor target) =>
FullShapeTK y -> target y -> target (ADTensorKind y)
toADTensorKindShared FullShapeTK (TKX sh r1)
y (target (TKX sh r1) -> target (ADTensorKind (TKX sh r1)))
-> target (TKX sh r1) -> target (ADTensorKind (TKX sh r1))
forall a b. (a -> b) -> a -> b
$ target (TKX2 sh (TKScalar r2)) -> target (TKX sh r1)
forall r1 r2 (sh :: [Maybe Nat]).
(RealFrac r1, GoodScalar r1, RealFrac r2, GoodScalar r2) =>
target (TKX sh r1) -> target (TKX sh r2)
forall (target :: Target) r1 r2 (sh :: [Maybe Nat]).
(BaseTensor target, RealFrac r1, GoodScalar r1, RealFrac r2,
GoodScalar r2) =>
target (TKX sh r1) -> target (TKX sh r2)
txcast target (ADTensorKind y)
target (TKX2 sh (TKScalar r2))
c) Delta target (TKX sh r1)
d
DeltaSum0X Delta target (TKX2 sh r)
d -> case Delta target (TKX2 sh r) -> FullShapeTK (TKX2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 sh r)
d of
FTKX IShX sh
sh FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
StaticShX sh
-> (KnownShX sh => EvalState target) -> EvalState target
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX (IShX sh -> StaticShX sh
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX IShX sh
sh) ((KnownShX sh => EvalState target) -> EvalState target)
-> (KnownShX sh => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKX2 sh r))
-> Delta target (TKX2 sh r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (IShX sh -> target (TKX2 ('[] @(Maybe Nat)) r) -> target (TKX2 sh r)
forall (sh :: [Maybe Nat]) (x :: TK).
(KnownShX sh, KnownSTK x) =>
IShX sh -> target (TKX2 ('[] @(Maybe Nat)) x) -> target (TKX2 sh x)
forall (target :: Target) (sh :: [Maybe Nat]) (x :: TK).
(BaseTensor target, KnownShX sh, KnownSTK x) =>
IShX sh -> target (TKX2 ('[] @(Maybe Nat)) x) -> target (TKX2 sh x)
txreplicate0N IShX sh
sh target (ADTensorKind y)
target (TKX2 ('[] @(Maybe Nat)) r)
c) Delta target (TKX2 sh r)
d
DeltaDot0X target (TKX sh r)
v Delta target (TKX sh r)
d -> case Delta target (TKX sh r) -> FullShapeTK (TKX sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX sh r)
d of
FTKX IShX sh
sh FullShapeTK x
FTKScalar ->
StaticShX sh
-> (KnownShX sh => EvalState target) -> EvalState target
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX (IShX sh -> StaticShX sh
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX IShX sh
sh) ((KnownShX sh => EvalState target) -> EvalState target)
-> (KnownShX sh => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKX sh r))
-> Delta target (TKX sh r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (target (TKX sh r)
v target (TKX sh r) -> target (TKX sh r) -> target (TKX sh r)
forall a. Num a => a -> a -> a
* IShX sh
-> target (TKX2 ('[] @(Maybe Nat)) (TKScalar r))
-> target (TKX sh r)
forall (sh :: [Maybe Nat]) (x :: TK).
(KnownShX sh, KnownSTK x) =>
IShX sh -> target (TKX2 ('[] @(Maybe Nat)) x) -> target (TKX2 sh x)
forall (target :: Target) (sh :: [Maybe Nat]) (x :: TK).
(BaseTensor target, KnownShX sh, KnownSTK x) =>
IShX sh -> target (TKX2 ('[] @(Maybe Nat)) x) -> target (TKX2 sh x)
txreplicate0N (target (TKX sh r) -> IShX sh
forall (sh :: [Maybe Nat]) (x :: TK).
KnownSTK x =>
target (TKX2 sh x) -> IShX sh
forall (target :: Target) (sh :: [Maybe Nat]) (x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKX2 sh x) -> IShX sh
xshape target (TKX sh r)
v) target (ADTensorKind y)
target (TKX2 ('[] @(Maybe Nat)) (TKScalar r))
c) Delta target (TKX sh r)
d
DeltaIndexX @shm @shn StaticShX shn
shn Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d IxXOf target shm
ix -> case Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
-> FullShapeTK (TKX2 ((++) @(Maybe Nat) shm shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d of
FTKX IShX sh
sh FullShapeTK x
x | SNat @len <- IxXOf target shm -> SNat (Rank @(Maybe Nat) shm)
forall (sh :: [Maybe Nat]) i.
IxX sh i -> SNat (Rank @(Maybe Nat) sh)
ixxRank IxXOf target shm
ix ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
StaticShX shn
-> (KnownShX shn => EvalState target) -> EvalState target
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX shn
shn ((KnownShX shn => EvalState target) -> EvalState target)
-> (KnownShX shn => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
StaticShX sh
-> (KnownShX sh => EvalState target) -> EvalState target
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX (IShX sh -> StaticShX sh
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX IShX sh
sh) ((KnownShX sh => EvalState target) -> EvalState target)
-> (KnownShX sh => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
StaticShX shm
-> (KnownShX shm => EvalState target) -> EvalState target
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX (forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]) i.
StaticShX ((++) @(Maybe Nat) sh sh') -> IxX sh i -> StaticShX sh
ssxTakeIx @shm @shn (IShX sh -> StaticShX sh
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX IShX sh
sh) IxXOf target shm
ix) ((KnownShX shm => EvalState target) -> EvalState target)
-> (KnownShX shm => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
(:~:)
@[Maybe Nat] (Take @(Maybe Nat) (Rank @(Maybe Nat) shm) sh) shm
-> (((Take @(Maybe Nat) (Rank @(Maybe Nat) shm) sh :: [Maybe Nat])
~ (shm :: [Maybe Nat])) =>
EvalState target)
-> EvalState target
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:)
@[Maybe Nat] (Take @(Maybe Nat) (Rank @(Maybe Nat) shm) sh) shm
(:~:)
@[Maybe Nat]
(Take
@(Maybe Nat) (Rank @(Maybe Nat) shm) ((++) @(Maybe Nat) shm shn))
shm
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl :: Take (Rank shm) (shm ++ shn) :~: shm) ((((Take @(Maybe Nat) (Rank @(Maybe Nat) shm) sh :: [Maybe Nat])
~ (shm :: [Maybe Nat])) =>
EvalState target)
-> EvalState target)
-> (((Take @(Maybe Nat) (Rank @(Maybe Nat) shm) sh :: [Maybe Nat])
~ (shm :: [Maybe Nat])) =>
EvalState target)
-> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKX2 ((++) @(Maybe Nat) shm shn) r))
-> Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (IShX shm
-> target (TKX2 shn r)
-> IxXOf target shm
-> target (TKX2 ((++) @(Maybe Nat) shm shn) r)
forall (sh1 :: [Maybe Nat]) (sh2 :: [Maybe Nat]) (x :: TK).
(KnownShX sh1, KnownShX sh2, KnownSTK x,
(BoolOf (PrimalOf target) :: Type) ~ (BoolOf target :: Type),
EqH (PrimalOf target) (TKScalar Int64), ConvertTensor target) =>
IShX sh1
-> target (TKX2 sh2 x)
-> IxXOf target sh1
-> target (TKX2 ((++) @(Maybe Nat) sh1 sh2) x)
forall (target :: Target) (sh1 :: [Maybe Nat]) (sh2 :: [Maybe Nat])
(x :: TK).
(BaseTensor target, KnownShX sh1, KnownShX sh2, KnownSTK x,
(BoolOf (PrimalOf target) :: Type) ~ (BoolOf target :: Type),
EqH (PrimalOf target) (TKScalar Int64), ConvertTensor target) =>
IShX sh1
-> target (TKX2 sh2 x)
-> IxXOf target sh1
-> target (TKX2 ((++) @(Maybe Nat) sh1 sh2) x)
txoneHot (forall (len :: Nat) (sh :: [Maybe Nat]).
(KnownNat len, KnownShX sh, KnownShX (Take @(Maybe Nat) len sh)) =>
IShX sh -> IShX (Take @(Maybe Nat) len sh)
shxTake @len IShX sh
sh) target (ADTensorKind y)
target (TKX2 shn r)
c IxXOf target shm
ix) Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d
DeltaScatterX @shm @shn StaticShX shm
shm StaticShX shn
shn StaticShX shp
shp IShX ((++) @(Maybe Nat) shp shn)
_sh Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d IxXOf target shm -> IxXOf target shp
f -> case Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
-> FullShapeTK (TKX2 ((++) @(Maybe Nat) shm shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d of
FTKX IShX sh
sh FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
StaticShX shm
-> (KnownShX shm => EvalState target) -> EvalState target
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX shm
shm ((KnownShX shm => EvalState target) -> EvalState target)
-> (KnownShX shm => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
StaticShX shn
-> (KnownShX shn => EvalState target) -> EvalState target
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX shn
shn ((KnownShX shn => EvalState target) -> EvalState target)
-> (KnownShX shn => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
StaticShX shp
-> (KnownShX shp => EvalState target) -> EvalState target
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX shp
shp ((KnownShX shp => EvalState target) -> EvalState target)
-> (KnownShX shp => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKX2 sh r))
-> Delta target (TKX2 sh r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (forall (target :: Target) (shm :: [Maybe Nat]) (shn :: [Maybe Nat])
(shp :: [Maybe Nat]) (x :: TK).
(BaseTensor target, KnownShX shm, KnownShX shn, KnownShX shp,
KnownSTK x) =>
IShX ((++) @(Maybe Nat) shm shn)
-> target (TKX2 ((++) @(Maybe Nat) shp shn) x)
-> (IxXOf target shm -> IxXOf target shp)
-> target (TKX2 ((++) @(Maybe Nat) shm shn) x)
txgather @_ @shm @shn IShX sh
IShX ((++) @(Maybe Nat) shm shn)
sh target (ADTensorKind y)
target (TKX2 ((++) @(Maybe Nat) shp shn) r)
c IxXOf target shm -> IxXOf target shp
f) Delta target (TKX2 sh r)
Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d
DeltaGatherX @shm @shn StaticShX shm
shm StaticShX shn
shn StaticShX shp
shp IShX ((++) @(Maybe Nat) shm shn)
_sh Delta target (TKX2 ((++) @(Maybe Nat) shp shn) r)
d IxXOf target shm -> IxXOf target shp
f -> case Delta target (TKX2 ((++) @(Maybe Nat) shp shn) r)
-> FullShapeTK (TKX2 ((++) @(Maybe Nat) shp shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((++) @(Maybe Nat) shp shn) r)
d of
FTKX IShX sh
sh FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
StaticShX shm
-> (KnownShX shm => EvalState target) -> EvalState target
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX shm
shm ((KnownShX shm => EvalState target) -> EvalState target)
-> (KnownShX shm => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
StaticShX shn
-> (KnownShX shn => EvalState target) -> EvalState target
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX shn
shn ((KnownShX shn => EvalState target) -> EvalState target)
-> (KnownShX shn => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
StaticShX shp
-> (KnownShX shp => EvalState target) -> EvalState target
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX shp
shp ((KnownShX shp => EvalState target) -> EvalState target)
-> (KnownShX shp => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKX2 sh r))
-> Delta target (TKX2 sh r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (forall (target :: Target) (shm :: [Maybe Nat]) (shn :: [Maybe Nat])
(shp :: [Maybe Nat]) (x :: TK).
(BaseTensor target, KnownShX shm, KnownShX shn, KnownShX shp,
KnownSTK x) =>
IShX ((++) @(Maybe Nat) shp shn)
-> target (TKX2 ((++) @(Maybe Nat) shm shn) x)
-> (IxXOf target shm -> IxXOf target shp)
-> target (TKX2 ((++) @(Maybe Nat) shp shn) x)
txscatter @_ @shm @shn IShX sh
IShX ((++) @(Maybe Nat) shp shn)
sh target (ADTensorKind y)
target (TKX2 ((++) @(Maybe Nat) shm shn) r)
c IxXOf target shm -> IxXOf target shp
f) Delta target (TKX2 sh r)
Delta target (TKX2 ((++) @(Maybe Nat) shp shn) r)
d
DeltaAppendX Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
d Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
e -> case (Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
-> FullShapeTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
d, Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
-> FullShapeTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
e) of
(FTKX (Nested.SKnown m :: SNat n1
m@SNat n1
SNat :$% ShX sh Int
_) FullShapeTK x
x, FTKX (Nested.SKnown SNat n1
SNat :$% ShX sh Int
_) FullShapeTK x
_) ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
let cShared :: target (TKX2 ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh) r)
cShared = target (TKX2 ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh) r)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh) r)
forall (y :: TK). target y -> target y
forall (target :: Target) (y :: TK).
ShareTensor target =>
target y -> target y
tshare target (ADTensorKind y)
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh) r)
c
s2 :: EvalState target
s2 = EvalState target
-> target
(ADTensorKind (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r))
-> Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (SNat 0
-> SNat m
-> SNat n
-> target
(TKX2 ((':) @(Maybe Nat) ('Just @Nat ((0 + m) + n)) sh) r)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
forall (i :: Nat) (n :: Nat) (k :: Nat) (sh :: [Maybe Nat])
(x :: TK).
KnownSTK x =>
SNat i
-> SNat n
-> SNat k
-> target
(TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
forall (target :: Target) (i :: Nat) (n :: Nat) (k :: Nat)
(sh :: [Maybe Nat]) (x :: TK).
(BaseTensor target, KnownSTK x) =>
SNat i
-> SNat n
-> SNat k
-> target
(TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
txslice (forall (n :: Nat). KnownNat n => SNat n
SNat @0) SNat m
forall (n :: Nat). KnownNat n => SNat n
SNat SNat n
forall (n :: Nat). KnownNat n => SNat n
SNat target (TKX2 ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh) r)
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((0 + m) + n)) sh) r)
cShared) Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
d
in EvalState target
-> target
(ADTensorKind (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r))
-> Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s2 (SNat n1
-> SNat n
-> SNat 0
-> target
(TKX2 ((':) @(Maybe Nat) ('Just @Nat ((n1 + n) + 0)) sh) r)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
forall (i :: Nat) (n :: Nat) (k :: Nat) (sh :: [Maybe Nat])
(x :: TK).
KnownSTK x =>
SNat i
-> SNat n
-> SNat k
-> target
(TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
forall (target :: Target) (i :: Nat) (n :: Nat) (k :: Nat)
(sh :: [Maybe Nat]) (x :: TK).
(BaseTensor target, KnownSTK x) =>
SNat i
-> SNat n
-> SNat k
-> target
(TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
txslice SNat n1
m SNat n
forall (n :: Nat). KnownNat n => SNat n
SNat SNat 0
forall (n :: Nat). KnownNat n => SNat n
SNat target (TKX2 ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh) r)
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((n1 + n) + 0)) sh) r)
cShared) Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
e
DeltaSliceX i :: SNat i
i@SNat i
SNat SNat n
_ k :: SNat k
k@SNat k
SNat Delta
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
d -> case Delta
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
-> FullShapeTK
(TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
d of
FTKX (SMayNat @Nat Int SNat n
_ :$% ShX sh Int
sh) FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target
(ADTensorKind
(TKX2 ((':) @(Maybe Nat) ('Just @Nat (i + (n + k))) sh) r))
-> Delta
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat (i + (n + k))) sh) r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (target (TKX2 ((':) @(Maybe Nat) ('Just @Nat i) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat (n + k)) sh) x)
-> target
(TKX2 ((':) @(Maybe Nat) ('Just @Nat (i + (n + k))) sh) x)
forall (m :: Nat) (n :: Nat) (sh :: [Maybe Nat]) (x :: TK).
KnownSTK x =>
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh) x)
forall (target :: Target) (m :: Nat) (n :: Nat) (sh :: [Maybe Nat])
(x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh) x)
txappend
(FullShapeTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat i) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat i) sh) x)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget (IShX ((':) @(Maybe Nat) ('Just @Nat i) sh)
-> FullShapeTK x
-> FullShapeTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat i) sh) x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX (SNat i -> SMayNat @Nat Int SNat ('Just @Nat i)
forall {k} (f :: k -> Type) (n1 :: k) i.
f n1 -> SMayNat @k i f ('Just @k n1)
Nested.SKnown SNat i
i SMayNat @Nat Int SNat ('Just @Nat i)
-> ShX sh Int -> IShX ((':) @(Maybe Nat) ('Just @Nat i) sh)
forall {sh1 :: [Maybe Nat]} {i} (n :: Maybe Nat)
(sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat i SNat n -> ShX sh i -> ShX sh1 i
:$% ShX sh Int
sh) FullShapeTK x
x))
(target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat (n + k)) sh) x)
forall (m :: Nat) (n :: Nat) (sh :: [Maybe Nat]) (x :: TK).
KnownSTK x =>
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh) x)
forall (target :: Target) (m :: Nat) (n :: Nat) (sh :: [Maybe Nat])
(x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh) x)
txappend
target (ADTensorKind y)
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
c (FullShapeTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) x)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget
(IShX ((':) @(Maybe Nat) ('Just @Nat k) sh)
-> FullShapeTK x
-> FullShapeTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat k) sh) x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX (SNat k -> SMayNat @Nat Int SNat ('Just @Nat k)
forall {k} (f :: k -> Type) (n1 :: k) i.
f n1 -> SMayNat @k i f ('Just @k n1)
Nested.SKnown SNat k
k SMayNat @Nat Int SNat ('Just @Nat k)
-> ShX sh Int -> IShX ((':) @(Maybe Nat) ('Just @Nat k) sh)
forall {sh1 :: [Maybe Nat]} {i} (n :: Maybe Nat)
(sh :: [Maybe Nat]).
(((':) @(Maybe Nat) n sh :: [Maybe Nat]) ~ (sh1 :: [Maybe Nat])) =>
SMayNat @Nat i SNat n -> ShX sh i -> ShX sh1 i
:$% ShX sh Int
sh) FullShapeTK x
x)))) Delta
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat (i + (n + k))) sh) r)
Delta
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
d
DeltaReverseX Delta target (TKX2 ((':) @(Maybe Nat) mn sh) r)
d -> case Delta target (TKX2 ((':) @(Maybe Nat) mn sh) r)
-> FullShapeTK (TKX2 ((':) @(Maybe Nat) mn sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((':) @(Maybe Nat) mn sh) r)
d of
FTKX IShX sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKX2 ((':) @(Maybe Nat) mn sh) r))
-> Delta target (TKX2 ((':) @(Maybe Nat) mn sh) r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (target (TKX2 ((':) @(Maybe Nat) mn sh) r)
-> target (TKX2 ((':) @(Maybe Nat) mn sh) r)
forall (mn :: Maybe Nat) (sh :: [Maybe Nat]) (x :: TK).
KnownSTK x =>
target (TKX2 ((':) @(Maybe Nat) mn sh) x)
-> target (TKX2 ((':) @(Maybe Nat) mn sh) x)
forall (target :: Target) (mn :: Maybe Nat) (sh :: [Maybe Nat])
(x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKX2 ((':) @(Maybe Nat) mn sh) x)
-> target (TKX2 ((':) @(Maybe Nat) mn sh) x)
txreverse target (ADTensorKind y)
target (TKX2 ((':) @(Maybe Nat) mn sh) r)
c) Delta target (TKX2 ((':) @(Maybe Nat) mn sh) r)
d
DeltaTransposeX @perm @sh2 Perm perm
perm Delta target (TKX2 sh r)
d -> case Delta target (TKX2 sh r) -> FullShapeTK (TKX2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 sh r)
d of
FTKX IShX sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
Perm perm
-> (forall (is' :: [Nat]).
(Assert
(AllElem' @Nat is' (Count 0 (Rank @Nat is'))) (TypeError ...),
Assert
(AllElem' @Nat (Count 0 (Rank @Nat is')) is') (TypeError ...)) =>
Perm is'
-> (forall (sh :: [Maybe Nat]).
((Rank @(Maybe Nat) sh :: Nat) ~ (Rank @Nat perm :: Nat)) =>
StaticShX sh
-> (:~:)
@[Maybe Nat]
(Permute @(Maybe Nat) is' (Permute @(Maybe Nat) perm sh))
sh)
-> EvalState target)
-> EvalState target
forall (is :: [Nat]) r.
Perm is
-> (forall (is' :: [Nat]).
IsPermutation is' =>
Perm is'
-> (forall (sh :: [Maybe Nat]).
((Rank @(Maybe Nat) sh :: Nat) ~ (Rank @Nat is :: Nat)) =>
StaticShX sh
-> (:~:)
@[Maybe Nat]
(Permute @(Maybe Nat) is' (Permute @(Maybe Nat) is sh))
sh)
-> r)
-> r
permInverse Perm perm
perm ((forall (is' :: [Nat]).
(Assert
(AllElem' @Nat is' (Count 0 (Rank @Nat is'))) (TypeError ...),
Assert
(AllElem' @Nat (Count 0 (Rank @Nat is')) is') (TypeError ...)) =>
Perm is'
-> (forall (sh :: [Maybe Nat]).
((Rank @(Maybe Nat) sh :: Nat) ~ (Rank @Nat perm :: Nat)) =>
StaticShX sh
-> (:~:)
@[Maybe Nat]
(Permute @(Maybe Nat) is' (Permute @(Maybe Nat) perm sh))
sh)
-> EvalState target)
-> EvalState target)
-> (forall (is' :: [Nat]).
(Assert
(AllElem' @Nat is' (Count 0 (Rank @Nat is'))) (TypeError ...),
Assert
(AllElem' @Nat (Count 0 (Rank @Nat is')) is') (TypeError ...)) =>
Perm is'
-> (forall (sh :: [Maybe Nat]).
((Rank @(Maybe Nat) sh :: Nat) ~ (Rank @Nat perm :: Nat)) =>
StaticShX sh
-> (:~:)
@[Maybe Nat]
(Permute @(Maybe Nat) is' (Permute @(Maybe Nat) perm sh))
sh)
-> EvalState target)
-> EvalState target
forall a b. (a -> b) -> a -> b
$ \(Perm is'
permR :: Permutation.Perm permR) forall (sh :: [Maybe Nat]).
((Rank @(Maybe Nat) sh :: Nat) ~ (Rank @Nat perm :: Nat)) =>
StaticShX sh
-> (:~:)
@[Maybe Nat]
(Permute @(Maybe Nat) is' (Permute @(Maybe Nat) perm sh))
sh
_ ->
(:~:)
@[Maybe Nat]
((++)
@(Maybe Nat)
(Permute
@(Maybe Nat)
is'
(TakeLen
@(Maybe Nat) @Nat is' (PermutePrefix @(Maybe Nat) perm sh)))
(DropLen
@(Maybe Nat) @Nat is' (PermutePrefix @(Maybe Nat) perm sh)))
sh
-> ((((++)
@(Maybe Nat)
(Permute
@(Maybe Nat)
is'
(TakeLen
@(Maybe Nat) @Nat is' (PermutePrefix @(Maybe Nat) perm sh)))
(DropLen
@(Maybe Nat)
@Nat
is'
(PermutePrefix @(Maybe Nat) perm sh)) :: [Maybe Nat])
~ (sh :: [Maybe Nat])) =>
EvalState target)
-> EvalState target
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:)
@[Maybe Nat]
((++)
@(Maybe Nat)
(Permute
@(Maybe Nat)
is'
(TakeLen
@(Maybe Nat) @Nat is' (PermutePrefix @(Maybe Nat) perm sh)))
(DropLen
@(Maybe Nat) @Nat is' (PermutePrefix @(Maybe Nat) perm sh)))
sh
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl
:: Permutation.PermutePrefix
permR (Permutation.PermutePrefix perm sh2) :~: sh2) (((((++)
@(Maybe Nat)
(Permute
@(Maybe Nat)
is'
(TakeLen
@(Maybe Nat) @Nat is' (PermutePrefix @(Maybe Nat) perm sh)))
(DropLen
@(Maybe Nat)
@Nat
is'
(PermutePrefix @(Maybe Nat) perm sh)) :: [Maybe Nat])
~ (sh :: [Maybe Nat])) =>
EvalState target)
-> EvalState target)
-> ((((++)
@(Maybe Nat)
(Permute
@(Maybe Nat)
is'
(TakeLen
@(Maybe Nat) @Nat is' (PermutePrefix @(Maybe Nat) perm sh)))
(DropLen
@(Maybe Nat)
@Nat
is'
(PermutePrefix @(Maybe Nat) perm sh)) :: [Maybe Nat])
~ (sh :: [Maybe Nat])) =>
EvalState target)
-> EvalState target
forall a b. (a -> b) -> a -> b
$
(:~:)
@Nat
(Rank @(Maybe Nat) (PermutePrefix @(Maybe Nat) perm sh))
(Rank @(Maybe Nat) sh)
-> (((Rank
@(Maybe Nat) (PermutePrefix @(Maybe Nat) perm sh) :: Nat)
~ (Rank @(Maybe Nat) sh :: Nat)) =>
EvalState target)
-> EvalState target
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:)
@Nat
(Rank @(Maybe Nat) (PermutePrefix @(Maybe Nat) perm sh))
(Rank @(Maybe Nat) sh)
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl
:: Rank (Permutation.PermutePrefix perm sh2) :~: Rank sh2) ((((Rank @(Maybe Nat) (PermutePrefix @(Maybe Nat) perm sh) :: Nat)
~ (Rank @(Maybe Nat) sh :: Nat)) =>
EvalState target)
-> EvalState target)
-> (((Rank
@(Maybe Nat) (PermutePrefix @(Maybe Nat) perm sh) :: Nat)
~ (Rank @(Maybe Nat) sh :: Nat)) =>
EvalState target)
-> EvalState target
forall a b. (a -> b) -> a -> b
$
(:~:) @Nat (Rank @Nat is') (Rank @Nat perm)
-> (((Rank @Nat is' :: Nat) ~ (Rank @Nat perm :: Nat)) =>
EvalState target)
-> EvalState target
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:) @Nat (Rank @Nat is') (Rank @Nat perm)
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl
:: Rank permR :~: Rank perm) ((((Rank @Nat is' :: Nat) ~ (Rank @Nat perm :: Nat)) =>
EvalState target)
-> EvalState target)
-> (((Rank @Nat is' :: Nat) ~ (Rank @Nat perm :: Nat)) =>
EvalState target)
-> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKX2 sh r))
-> Delta target (TKX2 sh r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (Perm is'
-> target (TKX2 (PermutePrefix @(Maybe Nat) perm sh) r)
-> target
(TKX2
((++)
@(Maybe Nat)
(Permute
@(Maybe Nat)
is'
(TakeLen
@(Maybe Nat) @Nat is' (PermutePrefix @(Maybe Nat) perm sh)))
(DropLen
@(Maybe Nat) @Nat is' (PermutePrefix @(Maybe Nat) perm sh)))
r)
forall (perm :: [Nat]) (sh :: [Maybe Nat]) (x :: TK).
(IsPermutation perm,
(<=) @Nat (Rank @Nat perm) (Rank @(Maybe Nat) sh), KnownSTK x) =>
Perm perm
-> target (TKX2 sh x)
-> target (TKX2 (PermutePrefix @(Maybe Nat) perm sh) x)
forall (target :: Target) (perm :: [Nat]) (sh :: [Maybe Nat])
(x :: TK).
(BaseTensor target, IsPermutation perm,
(<=) @Nat (Rank @Nat perm) (Rank @(Maybe Nat) sh), KnownSTK x) =>
Perm perm
-> target (TKX2 sh x)
-> target (TKX2 (PermutePrefix @(Maybe Nat) perm sh) x)
txtranspose Perm is'
permR target (ADTensorKind y)
target (TKX2 (PermutePrefix @(Maybe Nat) perm sh) r)
c) Delta target (TKX2 sh r)
d
DeltaReshapeX IShX sh2
_sh2 Delta target (TKX2 sh r)
d -> case Delta target (TKX2 sh r) -> FullShapeTK (TKX2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 sh r)
d of
FTKX IShX sh
sh FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => EvalState target) -> EvalState target
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => EvalState target) -> EvalState target)
-> (KnownSTK x => EvalState target) -> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind (TKX2 sh r))
-> Delta target (TKX2 sh r)
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame EvalState target
s (IShX sh -> target (TKX2 sh2 r) -> target (TKX2 sh r)
forall (sh :: [Maybe Nat]) (sh2 :: [Maybe Nat]) (x :: TK).
KnownSTK x =>
IShX sh2 -> target (TKX2 sh x) -> target (TKX2 sh2 x)
forall (target :: Target) (sh :: [Maybe Nat]) (sh2 :: [Maybe Nat])
(x :: TK).
(BaseTensor target, KnownSTK x) =>
IShX sh2 -> target (TKX2 sh x) -> target (TKX2 sh2 x)
txreshape IShX sh
sh target (ADTensorKind y)
target (TKX2 sh2 r)
c) Delta target (TKX2 sh r)
d
DeltaConvert @a TKConversion a1 y
c1 Delta target a1
d -> case Delta target a1 -> FullShapeTK a1
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target a1
d of
FullShapeTK a1
aftk ->
(:~:) @TK (ADTensorKind a1) a1
-> (((ADTensorKind a1 :: TK) ~ (a1 :: TK)) => EvalState target)
-> EvalState target
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:) @TK (ADTensorKind a1) a1
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl :: ADTensorKind a :~: a) ((((ADTensorKind a1 :: TK) ~ (a1 :: TK)) => EvalState target)
-> EvalState target)
-> (((ADTensorKind a1 :: TK) ~ (a1 :: TK)) => EvalState target)
-> EvalState target
forall a b. (a -> b) -> a -> b
$
EvalState target
-> target (ADTensorKind a1) -> Delta target a1 -> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevSame
EvalState target
s (TKConversion y a1 -> SingletonTK y -> target y -> target a1
forall (a :: TK) (b :: TK).
TKConversion a b -> SingletonTK a -> target a -> target b
forall (target :: Target) (a :: TK) (b :: TK).
ConvertTensor target =>
TKConversion a b -> SingletonTK a -> target a -> target b
tconvert (FullShapeTK a1 -> TKConversion a1 y -> TKConversion y a1
forall (a :: TK) (b :: TK).
FullShapeTK a -> TKConversion a b -> TKConversion b a
transposeTKConversion FullShapeTK a1
aftk TKConversion a1 y
c1)
(TKConversion a1 y -> SingletonTK a1 -> SingletonTK y
forall (a :: TK) (b :: TK).
TKConversion a b -> SingletonTK a -> SingletonTK b
convertSTK TKConversion a1 y
c1 (SingletonTK a1 -> SingletonTK y)
-> SingletonTK a1 -> SingletonTK y
forall a b. (a -> b) -> a -> b
$ FullShapeTK a1 -> SingletonTK a1
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (FullShapeTK a1 -> SingletonTK a1)
-> FullShapeTK a1 -> SingletonTK a1
forall a b. (a -> b) -> a -> b
$ Delta target a1 -> FullShapeTK a1
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target a1
d) target y
target (ADTensorKind y)
c) Delta target a1
d
Delta target y
d -> EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
EvalState target
-> target (ADTensorKind y) -> Delta target y -> EvalState target
evalRevFTK EvalState target
s target (ADTensorKind y)
c Delta target y
d
transposeTKConversion :: FullShapeTK a -> TKConversion a b -> TKConversion b a
transposeTKConversion :: forall (a :: TK) (b :: TK).
FullShapeTK a -> TKConversion a b -> TKConversion b a
transposeTKConversion FullShapeTK a
aftk TKConversion a b
c0 = case TKConversion a b
c0 of
TKConversion a b
ConvId -> TKConversion b a
TKConversion b b
forall (a :: TK). TKConversion a a
ConvId
ConvCmp TKConversion b1 b
c1 TKConversion a b1
c2 -> TKConversion b1 a -> TKConversion b b1 -> TKConversion b a
forall (b1 :: TK) (b :: TK) (a :: TK).
TKConversion b1 b -> TKConversion a b1 -> TKConversion a b
ConvCmp (FullShapeTK a -> TKConversion a b1 -> TKConversion b1 a
forall (a :: TK) (b :: TK).
FullShapeTK a -> TKConversion a b -> TKConversion b a
transposeTKConversion FullShapeTK a
aftk TKConversion a b1
c2)
(FullShapeTK b1 -> TKConversion b1 b -> TKConversion b b1
forall (a :: TK) (b :: TK).
FullShapeTK a -> TKConversion a b -> TKConversion b a
transposeTKConversion (TKConversion a b1 -> FullShapeTK a -> FullShapeTK b1
forall (a :: TK) (b :: TK).
TKConversion a b -> FullShapeTK a -> FullShapeTK b
convertFTK TKConversion a b1
c2 FullShapeTK a
aftk) TKConversion b1 b
c1)
TKConversion a b
ConvRX | FTKR @n IShR n
_ FullShapeTK x
x <- FullShapeTK a
aftk
, (:~:)
@Nat
(Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
n
Refl <- Proxy @Nat n
-> (:~:)
@Nat
(Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
n
forall (proxy :: Nat -> Type) (n :: Nat).
proxy n
-> (:~:)
@Nat
(Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
n
lemRankReplicate (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @n) ->
SingletonTK x
-> TKConversion
(TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) x)
(TKR2
(Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat))) x)
forall (a1 :: TK) (sh :: [Maybe Nat]).
SingletonTK a1
-> TKConversion (TKX2 sh a1) (TKR2 (Rank @(Maybe Nat) sh) a1)
ConvXR (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x)
TKConversion a b
ConvSX -> TKConversion b a
TKConversion (TKX2 (MapJust @Nat sh) a1) (TKS2 sh a1)
forall (sh :: [Nat]) (a1 :: TK).
TKConversion (TKX2 (MapJust @Nat sh) a1) (TKS2 sh a1)
ConvXS
ConvXR @_ @sh SingletonTK a1
_stk | (:~:)
@Nat
(Rank
@(Maybe Nat)
(Replicate @(Maybe Nat) (Rank @(Maybe Nat) sh) ('Nothing @Nat)))
(Rank @(Maybe Nat) sh)
Refl <- Proxy @Nat (Rank @(Maybe Nat) sh)
-> (:~:)
@Nat
(Rank
@(Maybe Nat)
(Replicate @(Maybe Nat) (Rank @(Maybe Nat) sh) ('Nothing @Nat)))
(Rank @(Maybe Nat) sh)
forall (proxy :: Nat -> Type) (n :: Nat).
proxy n
-> (:~:)
@Nat
(Rank @(Maybe Nat) (Replicate @(Maybe Nat) n ('Nothing @Nat)))
n
lemRankReplicate (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @(Rank sh)) ->
TKConversion
(TKX2
(Replicate @(Maybe Nat) (Rank @(Maybe Nat) sh) ('Nothing @Nat)) a1)
a
-> TKConversion
b
(TKX2
(Replicate @(Maybe Nat) (Rank @(Maybe Nat) sh) ('Nothing @Nat)) a1)
-> TKConversion b a
forall (b1 :: TK) (b :: TK) (a :: TK).
TKConversion b1 b -> TKConversion a b1 -> TKConversion a b
ConvCmp (FullShapeTK (TKX2 sh a1)
-> TKConversion
(TKX2
(Replicate @(Maybe Nat) (Rank @(Maybe Nat) sh) ('Nothing @Nat)) a1)
(TKX2 sh a1)
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]) (a1 :: TK).
((Rank @(Maybe Nat) sh :: Nat) ~ (Rank @(Maybe Nat) sh' :: Nat)) =>
FullShapeTK (TKX2 sh' a1)
-> TKConversion (TKX2 sh a1) (TKX2 sh' a1)
ConvXX' FullShapeTK a
FullShapeTK (TKX2 sh a1)
aftk) TKConversion
b
(TKX2
(Replicate @(Maybe Nat) (Rank @(Maybe Nat) sh) ('Nothing @Nat)) a1)
TKConversion
(TKR2 (Rank @(Maybe Nat) sh) a1)
(TKX2
(Replicate @(Maybe Nat) (Rank @(Maybe Nat) sh) ('Nothing @Nat)) a1)
forall (n :: Nat) (a1 :: TK).
TKConversion
(TKR2 n a1) (TKX2 (Replicate @(Maybe Nat) n ('Nothing @Nat)) a1)
ConvRX
TKConversion a b
ConvXS -> TKConversion b a
TKConversion (TKS2 sh a1) (TKX2 (MapJust @Nat sh) a1)
forall (sh :: [Nat]) (a1 :: TK).
TKConversion (TKS2 sh a1) (TKX2 (MapJust @Nat sh) a1)
ConvSX
ConvXS' (FTKS ShS sh
sh FullShapeTK x
_) | (:~:) @Nat (Rank @(Maybe Nat) (MapJust @Nat sh)) (Rank @Nat sh)
Refl <- ShS sh
-> (:~:) @Nat (Rank @(Maybe Nat) (MapJust @Nat sh)) (Rank @Nat sh)
forall (sh :: [Nat]).
ShS sh
-> (:~:) @Nat (Rank @(Maybe Nat) (MapJust @Nat sh)) (Rank @Nat sh)
lemRankMapJust ShS sh
sh ->
TKConversion (TKX2 (MapJust @Nat sh') a1) a
-> TKConversion b (TKX2 (MapJust @Nat sh') a1) -> TKConversion b a
forall (b1 :: TK) (b :: TK) (a :: TK).
TKConversion b1 b -> TKConversion a b1 -> TKConversion a b
ConvCmp (FullShapeTK (TKX2 sh a1)
-> TKConversion (TKX2 (MapJust @Nat sh') a1) (TKX2 sh a1)
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]) (a1 :: TK).
((Rank @(Maybe Nat) sh :: Nat) ~ (Rank @(Maybe Nat) sh' :: Nat)) =>
FullShapeTK (TKX2 sh' a1)
-> TKConversion (TKX2 sh a1) (TKX2 sh' a1)
ConvXX' FullShapeTK a
FullShapeTK (TKX2 sh a1)
aftk) TKConversion b (TKX2 (MapJust @Nat sh') a1)
TKConversion (TKS2 sh' a1) (TKX2 (MapJust @Nat sh') a1)
forall (sh :: [Nat]) (a1 :: TK).
TKConversion (TKS2 sh a1) (TKX2 (MapJust @Nat sh) a1)
ConvSX
ConvXX' FullShapeTK (TKX2 sh' a1)
_ftk -> FullShapeTK (TKX2 sh a1) -> TKConversion (TKX2 sh' a1) (TKX2 sh a1)
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]) (a1 :: TK).
((Rank @(Maybe Nat) sh :: Nat) ~ (Rank @(Maybe Nat) sh' :: Nat)) =>
FullShapeTK (TKX2 sh' a1)
-> TKConversion (TKX2 sh a1) (TKX2 sh' a1)
ConvXX' FullShapeTK a
FullShapeTK (TKX2 sh a1)
aftk
ConvRR TKConversion a1 b1
c | FTKR IShR n
_ FullShapeTK x
x <- FullShapeTK a
aftk -> TKConversion b1 x -> TKConversion (TKR2 n b1) (TKR2 n x)
forall (a1 :: TK) (b1 :: TK) (n :: Nat).
TKConversion a1 b1 -> TKConversion (TKR2 n a1) (TKR2 n b1)
ConvRR (FullShapeTK x -> TKConversion x b1 -> TKConversion b1 x
forall (a :: TK) (b :: TK).
FullShapeTK a -> TKConversion a b -> TKConversion b a
transposeTKConversion FullShapeTK x
x TKConversion a1 b1
TKConversion x b1
c)
ConvSS TKConversion a1 b1
c | FTKS ShS sh
_ FullShapeTK x
x <- FullShapeTK a
aftk -> TKConversion b1 x -> TKConversion (TKS2 sh b1) (TKS2 sh x)
forall (a1 :: TK) (b1 :: TK) (sh :: [Nat]).
TKConversion a1 b1 -> TKConversion (TKS2 sh a1) (TKS2 sh b1)
ConvSS (FullShapeTK x -> TKConversion x b1 -> TKConversion b1 x
forall (a :: TK) (b :: TK).
FullShapeTK a -> TKConversion a b -> TKConversion b a
transposeTKConversion FullShapeTK x
x TKConversion a1 b1
TKConversion x b1
c)
ConvXX TKConversion a1 b1
c | FTKX IShX sh
_ FullShapeTK x
x <- FullShapeTK a
aftk -> TKConversion b1 x -> TKConversion (TKX2 sh b1) (TKX2 sh x)
forall (a1 :: TK) (b1 :: TK) (sh :: [Maybe Nat]).
TKConversion a1 b1 -> TKConversion (TKX2 sh a1) (TKX2 sh b1)
ConvXX (FullShapeTK x -> TKConversion x b1 -> TKConversion b1 x
forall (a :: TK) (b :: TK).
FullShapeTK a -> TKConversion a b -> TKConversion b a
transposeTKConversion FullShapeTK x
x TKConversion a1 b1
TKConversion x b1
c)
ConvT2 TKConversion a1 a'
c1 TKConversion b1 b'
c2 | FTKProduct FullShapeTK y1
x1 FullShapeTK z
x2 <- FullShapeTK a
aftk ->
TKConversion a' y1
-> TKConversion b' z
-> TKConversion (TKProduct a' b') (TKProduct y1 z)
forall (a1 :: TK) (a' :: TK) (b1 :: TK) (b' :: TK).
TKConversion a1 a'
-> TKConversion b1 b'
-> TKConversion (TKProduct a1 b1) (TKProduct a' b')
ConvT2 (FullShapeTK y1 -> TKConversion y1 a' -> TKConversion a' y1
forall (a :: TK) (b :: TK).
FullShapeTK a -> TKConversion a b -> TKConversion b a
transposeTKConversion FullShapeTK y1
x1 TKConversion a1 a'
TKConversion y1 a'
c1) (FullShapeTK z -> TKConversion z b' -> TKConversion b' z
forall (a :: TK) (b :: TK).
FullShapeTK a -> TKConversion a b -> TKConversion b a
transposeTKConversion FullShapeTK z
x2 TKConversion b1 b'
TKConversion z b'
c2)
Conv0X SingletonTK a
_stk -> TKConversion b a
TKConversion (TKX2 ('[] @(Maybe Nat)) a) a
forall (b :: TK). TKConversion (TKX2 ('[] @(Maybe Nat)) b) b
ConvX0
TKConversion a b
ConvX0 | FTKX ShX sh Int
ZSX FullShapeTK x
x <- FullShapeTK a
aftk -> SingletonTK b -> TKConversion b (TKX2 ('[] @(Maybe Nat)) b)
forall (a :: TK).
SingletonTK a -> TKConversion a (TKX2 ('[] @(Maybe Nat)) a)
Conv0X (FullShapeTK b -> SingletonTK b
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK b
FullShapeTK x
x)
ConvNest SingletonTK (TKX2 sh a1)
_stk -> TKConversion b a
TKConversion
(TKX2 sh (TKX2 sh' a1)) (TKX2 ((++) @(Maybe Nat) sh sh') a1)
forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]) (a1 :: TK).
TKConversion
(TKX2 sh (TKX2 sh' a1)) (TKX2 ((++) @(Maybe Nat) sh sh') a1)
ConvUnnest
TKConversion a b
ConvUnnest | (FTKX IShX sh
shx (FTKX IShX sh
_ FullShapeTK x
x)) <- FullShapeTK a
aftk ->
SingletonTK (TKX2 sh x)
-> TKConversion
(TKX2 ((++) @(Maybe Nat) sh sh') x) (TKX2 sh (TKX2 sh' x))
forall (sh :: [Maybe Nat]) (a1 :: TK) (sh' :: [Maybe Nat]).
SingletonTK (TKX2 sh a1)
-> TKConversion
(TKX2 ((++) @(Maybe Nat) sh sh') a1) (TKX2 sh (TKX2 sh' a1))
ConvNest (StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
forall (sh :: [Maybe Nat]) (x :: TK).
StaticShX sh -> SingletonTK x -> SingletonTK (TKX2 sh x)
STKX (IShX sh -> StaticShX sh
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX IShX sh
shx) (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x))
ConvZip SingletonTK a1
stk1 SingletonTK b1
stk2 -> SingletonTK a1
-> SingletonTK b1
-> TKConversion
(TKX2 sh (TKProduct a1 b1)) (TKProduct (TKX2 sh a1) (TKX2 sh b1))
forall (a1 :: TK) (b1 :: TK) (sh :: [Maybe Nat]).
SingletonTK a1
-> SingletonTK b1
-> TKConversion
(TKX2 sh (TKProduct a1 b1)) (TKProduct (TKX2 sh a1) (TKX2 sh b1))
ConvUnzip SingletonTK a1
stk1 SingletonTK b1
stk2
ConvUnzip SingletonTK a1
stk1 SingletonTK b1
stk2 -> SingletonTK a1
-> SingletonTK b1
-> TKConversion
(TKProduct (TKX2 sh a1) (TKX2 sh b1)) (TKX2 sh (TKProduct a1 b1))
forall (a1 :: TK) (b1 :: TK) (sh :: [Maybe Nat]).
SingletonTK a1
-> SingletonTK b1
-> TKConversion
(TKProduct (TKX2 sh a1) (TKX2 sh b1)) (TKX2 sh (TKProduct a1 b1))
ConvZip SingletonTK a1
stk1 SingletonTK b1
stk2
evalRevFromnMap :: forall target. (ADReadyNoLet target, ShareTensor target)
=> EvalState target -> EvalState target
evalRevFromnMap :: forall (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
EvalState target -> EvalState target
evalRevFromnMap s :: EvalState target
s@EvalState{DEnumMap @TK (NodeId target) (Delta target)
nMap :: forall (target :: Target).
EvalState target -> DEnumMap @TK (NodeId target) (Delta target)
nMap :: DEnumMap @TK (NodeId target) (Delta target)
nMap, ADMap target
dMap :: forall (target :: Target). EvalState target -> ADMap target
dMap :: ADMap target
dMap} =
case DEnumMap @TK (NodeId target) (Delta target)
-> Maybe
(DSum @TK (NodeId target) (Delta target),
DEnumMap @TK (NodeId target) (Delta target))
forall {kind} (k :: kind -> Type) (v :: kind -> Type).
Enum1 @kind k =>
DEnumMap @kind k v -> Maybe (DSum @kind k v, DEnumMap @kind k v)
DMap.maxViewWithKey DEnumMap @TK (NodeId target) (Delta target)
nMap of
Just (NodeId target a
n :=> Delta target a
d, DEnumMap @TK (NodeId target) (Delta target)
nMap2) ->
let s2 :: EvalState target
s2 = EvalState target
s {nMap = nMap2}
s3 :: EvalState target
s3 = case NodeId target a -> ADMap target -> Maybe (Cotangent target a)
forall {kind} (k :: kind -> Type) (a :: kind) (v :: kind -> Type).
(Enum1 @kind k, TestEquality @kind k) =>
k a -> DEnumMap @kind k v -> Maybe (v a)
DMap.lookup NodeId target a
n ADMap target
dMap of
Just (Cotangent target (ADTensorKind a)
c) -> FullShapeTK a
-> EvalState target
-> target (ADTensorKind a)
-> Delta target a
-> EvalState target
forall (y :: TK) (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
FullShapeTK y
-> EvalState target
-> target (ADTensorKind y)
-> Delta target y
-> EvalState target
evalRev (NodeId target a -> FullShapeTK a
forall (f :: Target) (y :: TK). NodeId f y -> FullShapeTK y
nodeIdToFTK NodeId target a
n) EvalState target
s2 target (ADTensorKind a)
c Delta target a
d
Maybe (Cotangent target a)
Nothing -> String -> EvalState target
forall a. (?callStack::CallStack) => String -> a
error (String -> EvalState target) -> String -> EvalState target
forall a b. (a -> b) -> a -> b
$ String
"evalRevFromnMap: missing cotangent " String -> ShowS
forall a. [a] -> [a] -> [a]
++ NodeId target a -> String
forall a. Show a => a -> String
show NodeId target a
n
in EvalState target -> EvalState target
forall (target :: Target).
(ADReadyNoLet target, ShareTensor target) =>
EvalState target -> EvalState target
evalRevFromnMap EvalState target
s3
Maybe
(DSum @TK (NodeId target) (Delta target),
DEnumMap @TK (NodeId target) (Delta target))
Nothing -> EvalState target
s
evalFwd
:: forall target y. (ADReadyNoLet target, ShareTensor target)
=> IMap target -> ADMap target -> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwd :: forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwd IMap target
params ADMap target
s Delta target y
d0 = case Delta target y
d0 of
DeltaShare NodeId target y
n Delta target y
d ->
case NodeId target y -> ADMap target -> Maybe (Cotangent target y)
forall {kind} (k :: kind -> Type) (a :: kind) (v :: kind -> Type).
(Enum1 @kind k, TestEquality @kind k) =>
k a -> DEnumMap @kind k v -> Maybe (v a)
DMap.lookup NodeId target y
n ADMap target
s of
Just Cotangent target y
e1 -> (ADMap target
s, Cotangent target y -> target (ADTensorKind y)
forall (target :: Target) (y :: TK).
Cotangent target y -> target (ADTensorKind y)
unCotangent Cotangent target y
e1)
Maybe (Cotangent target y)
Nothing ->
let (ADMap target
s2, target (ADTensorKind y)
cRaw) = IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwd IMap target
params ADMap target
s Delta target y
d
cShared :: target (ADTensorKind y)
cShared = target (ADTensorKind y) -> target (ADTensorKind y)
forall (y :: TK). target y -> target y
forall (target :: Target) (y :: TK).
ShareTensor target =>
target y -> target y
tshare target (ADTensorKind y)
cRaw
cd :: Cotangent target y
cd = target (ADTensorKind y) -> Cotangent target y
forall (target :: Target) (y :: TK).
target (ADTensorKind y) -> Cotangent target y
Cotangent target (ADTensorKind y)
cShared
s3 :: ADMap target
s3 = NodeId target y
-> Cotangent target y -> ADMap target -> ADMap target
forall {kind} (k :: kind -> Type) (a :: kind) (v :: kind -> Type).
Enum1 @kind k =>
k a -> v a -> DEnumMap @kind k v -> DEnumMap @kind k v
DMap.insert NodeId target y
n Cotangent target y
cd ADMap target
s2
in (ADMap target
s3, target (ADTensorKind y)
cShared)
DeltaInput InputId target y
inputId ->
case InputId target y -> IMap target -> Maybe (TensorOrZero target y)
forall {kind} (k :: kind -> Type) (a :: kind) (v :: kind -> Type).
(Enum1 @kind k, TestEquality @kind k) =>
k a -> DEnumMap @kind k v -> Maybe (v a)
DMap.lookup InputId target y
inputId IMap target
params of
Just TensorOrZero target y
dtk -> (ADMap target
s, FullShapeTK y -> target y -> target (ADTensorKind y)
forall (target :: Target) (y :: TK).
(BaseTensor target, ConvertTensor target, ShareTensor target) =>
FullShapeTK y -> target y -> target (ADTensorKind y)
toADTensorKindShared (InputId target y -> FullShapeTK y
forall (f :: Target) (y :: TK). InputId f y -> FullShapeTK y
inputIdToFTK InputId target y
inputId)
(target y -> target (ADTensorKind y))
-> target y -> target (ADTensorKind y)
forall a b. (a -> b) -> a -> b
$ TensorOrZero target y -> target y
forall (target :: Target) (x :: TK).
ADReadyNoLet target =>
TensorOrZero target x -> target x
evalTensorOrZero TensorOrZero target y
dtk)
Maybe (TensorOrZero target y)
Nothing -> String -> (ADMap target, target (ADTensorKind y))
forall a. (?callStack::CallStack) => String -> a
error String
"evalFwd: missing input"
DeltaPair Delta target y
d1 Delta target z
d2 ->
let (ADMap target
s2, target (ADTensorKind y)
t) = IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwd IMap target
params ADMap target
s Delta target y
d1
(ADMap target
s3, target (ADTensorKind z)
u) = IMap target
-> ADMap target
-> Delta target z
-> (ADMap target, target (ADTensorKind z))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwd IMap target
params ADMap target
s2 Delta target z
d2
in (ADMap target
s3, target (ADTensorKind y)
-> target (ADTensorKind z)
-> target (TKProduct (ADTensorKind y) (ADTensorKind z))
forall (x :: TK) (z :: TK).
target x -> target z -> target (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair target (ADTensorKind y)
t target (ADTensorKind z)
u)
DeltaProject1 Delta target (TKProduct y z)
d ->
let (ADMap target
s2, target (ADTensorKind (TKProduct y z))
v) = IMap target
-> ADMap target
-> Delta target (TKProduct y z)
-> (ADMap target, target (ADTensorKind (TKProduct y z)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwd IMap target
params ADMap target
s Delta target (TKProduct y z)
d
in (ADMap target
s2, target (TKProduct (ADTensorKind y) (ADTensorKind z))
-> target (ADTensorKind y)
forall (x :: TK) (z :: TK). target (TKProduct x z) -> target x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 target (ADTensorKind (TKProduct y z))
target (TKProduct (ADTensorKind y) (ADTensorKind z))
v)
DeltaProject2 Delta target (TKProduct y y)
d ->
let (ADMap target
s2, target (ADTensorKind (TKProduct y y))
v) = IMap target
-> ADMap target
-> Delta target (TKProduct y y)
-> (ADMap target, target (ADTensorKind (TKProduct y y)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwd IMap target
params ADMap target
s Delta target (TKProduct y y)
d
in (ADMap target
s2, target (TKProduct (ADTensorKind y) (ADTensorKind y))
-> target (ADTensorKind y)
forall (x :: TK) (z :: TK). target (TKProduct x z) -> target z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 target (ADTensorKind (TKProduct y y))
target (TKProduct (ADTensorKind y) (ADTensorKind y))
v)
DeltaFromVector SNat k
snat SingletonTK y
stk Vector (Delta target y)
lsd | (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
Refl <- SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
forall (k :: Nat) (y :: TK).
SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
lemBuildOfAD SNat k
snat SingletonTK y
stk ->
let (ADMap target
s2, Vector (target (ADTensorKind y))
l) = (ADMap target
-> Delta target y -> (ADMap target, target (ADTensorKind y)))
-> ADMap target
-> Vector (Delta target y)
-> (ADMap target, Vector (target (ADTensorKind y)))
forall (t :: Type -> Type) s a b.
Traversable t =>
(s -> a -> (s, b)) -> s -> t a -> (s, t b)
mapAccumL (IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwd IMap target
params) ADMap target
s Vector (Delta target y)
lsd
in (ADMap target
s2, SNat k
-> SingletonTK (ADTensorKind y)
-> Vector (target (ADTensorKind y))
-> target (BuildTensorKind k (ADTensorKind y))
forall (y :: TK) (k :: Nat).
SNat k
-> SingletonTK y
-> Vector (target y)
-> target (BuildTensorKind k y)
forall (target :: Target) (y :: TK) (k :: Nat).
BaseTensor target =>
SNat k
-> SingletonTK y
-> Vector (target y)
-> target (BuildTensorKind k y)
tfromVector SNat k
snat (SingletonTK y -> SingletonTK (ADTensorKind y)
forall (y :: TK). SingletonTK y -> SingletonTK (ADTensorKind y)
adSTK SingletonTK y
stk) Vector (target (ADTensorKind y))
l)
DeltaSum SNat k
snat SingletonTK y
stk Delta target (BuildTensorKind k y)
d | (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
Refl <- SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
forall (k :: Nat) (y :: TK).
SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
lemBuildOfAD SNat k
snat SingletonTK y
stk ->
let (ADMap target
s2, target (ADTensorKind (BuildTensorKind k y))
t) = IMap target
-> ADMap target
-> Delta target (BuildTensorKind k y)
-> (ADMap target, target (ADTensorKind (BuildTensorKind k y)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwd IMap target
params ADMap target
s Delta target (BuildTensorKind k y)
d
in (ADMap target
s2, SNat k
-> SingletonTK (ADTensorKind y)
-> target (BuildTensorKind k (ADTensorKind y))
-> target (ADTensorKind y)
forall (z :: TK) (k :: Nat).
ConvertTensor target =>
SNat k -> SingletonTK z -> target (BuildTensorKind k z) -> target z
forall (target :: Target) (z :: TK) (k :: Nat).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target (BuildTensorKind k z) -> target z
tsum SNat k
snat (SingletonTK y -> SingletonTK (ADTensorKind y)
forall (y :: TK). SingletonTK y -> SingletonTK (ADTensorKind y)
adSTK SingletonTK y
stk) target (ADTensorKind (BuildTensorKind k y))
target (BuildTensorKind k (ADTensorKind y))
t)
DeltaReplicate SNat k
snat SingletonTK y
stk Delta target y
d | (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
Refl <- SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
forall (k :: Nat) (y :: TK).
SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
lemBuildOfAD SNat k
snat SingletonTK y
stk ->
let (ADMap target
s2, target (ADTensorKind y)
t) = IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwd IMap target
params ADMap target
s Delta target y
d
in (ADMap target
s2, SNat k
-> SingletonTK (ADTensorKind y)
-> target (ADTensorKind y)
-> target (BuildTensorKind k (ADTensorKind y))
forall (z :: TK) (k :: Nat).
ConvertTensor target =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Nat).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate SNat k
snat (SingletonTK y -> SingletonTK (ADTensorKind y)
forall (y :: TK). SingletonTK y -> SingletonTK (ADTensorKind y)
adSTK SingletonTK y
stk) target (ADTensorKind y)
t)
DeltaMapAccumR SNat k
k FullShapeTK by
bftk FullShapeTK ey
eftk target (BuildTensorKind k accy)
q target (BuildTensorKind k ey)
es HFun
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
df HFun
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
_rf Delta target accy
acc0' Delta target (BuildTensorKind k ey)
es'
| (:~:)
@TK
(BuildTensorKind k (ADTensorKind by))
(ADTensorKind (BuildTensorKind k by))
Refl <- SNat k
-> SingletonTK by
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind by))
(ADTensorKind (BuildTensorKind k by))
forall (k :: Nat) (y :: TK).
SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
lemBuildOfAD SNat k
k (FullShapeTK by -> SingletonTK by
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK by
bftk)
, (:~:)
@TK
(BuildTensorKind k (ADTensorKind ey))
(ADTensorKind (BuildTensorKind k ey))
Refl <- SNat k
-> SingletonTK ey
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind ey))
(ADTensorKind (BuildTensorKind k ey))
forall (k :: Nat) (y :: TK).
SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
lemBuildOfAD SNat k
k (FullShapeTK ey -> SingletonTK ey
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK ey
eftk) ->
let accftk :: FullShapeTK accy
accftk = Delta target accy -> FullShapeTK accy
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target accy
acc0'
accftkAD :: FullShapeTK (ADTensorKind accy)
accftkAD = FullShapeTK accy -> FullShapeTK (ADTensorKind accy)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK accy
accftk
bftkAD :: FullShapeTK (ADTensorKind by)
bftkAD = FullShapeTK by -> FullShapeTK (ADTensorKind by)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK by
bftk
eftkAD :: FullShapeTK (ADTensorKind ey)
eftkAD = FullShapeTK ey -> FullShapeTK (ADTensorKind ey)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK ey
eftk
(ADMap target
s2, target (ADTensorKind accy)
cacc0) = IMap target
-> ADMap target
-> Delta target accy
-> (ADMap target, target (ADTensorKind accy))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwd IMap target
params ADMap target
s Delta target accy
acc0'
(ADMap target
s3, target (ADTensorKind (BuildTensorKind k ey))
ces) = IMap target
-> ADMap target
-> Delta target (BuildTensorKind k ey)
-> (ADMap target, target (ADTensorKind (BuildTensorKind k ey)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwd IMap target
params ADMap target
s2 Delta target (BuildTensorKind k ey)
es'
in (ADMap target
s3, Proxy @Target target
-> SNat k
-> FullShapeTK (ADTensorKind accy)
-> FullShapeTK (ADTensorKind by)
-> FullShapeTK (TKProduct (ADTensorKind ey) (TKProduct accy ey))
-> (forall (f :: Target).
ADReady f =>
f (ADTensorKind accy)
-> f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> target (ADTensorKind accy)
-> target
(BuildTensorKind
k (TKProduct (ADTensorKind ey) (TKProduct accy ey)))
-> target
(TKProduct
(ADTensorKind accy) (BuildTensorKind k (ADTensorKind by)))
forall (accy :: TK) (by :: TK) (ey :: TK) (k :: Nat)
(target :: Target).
BaseTensor target =>
Proxy @Target target
-> SNat k
-> FullShapeTK accy
-> FullShapeTK by
-> FullShapeTK ey
-> (forall (f :: Target).
ADReady f =>
f accy -> f ey -> f (TKProduct accy by))
-> target accy
-> target (BuildTensorKind k ey)
-> target (TKProduct accy (BuildTensorKind k by))
tmapAccumR (forall {k} (t :: k). Proxy @k t
forall (t :: Target). Proxy @Target t
Proxy @target)
SNat k
k FullShapeTK (ADTensorKind accy)
accftkAD FullShapeTK (ADTensorKind by)
bftkAD (FullShapeTK (ADTensorKind ey)
-> FullShapeTK (TKProduct accy ey)
-> FullShapeTK (TKProduct (ADTensorKind ey) (TKProduct accy ey))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK (ADTensorKind ey)
eftkAD
(FullShapeTK accy
-> FullShapeTK ey -> FullShapeTK (TKProduct accy ey)
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK accy
accftk FullShapeTK ey
eftk))
(\f (ADTensorKind accy)
dacc f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
de_acc_e ->
f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
-> (f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
de_acc_e ((f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> (f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
de_acc_e1 ->
HFun
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind by))
-> forall (f :: Target).
ADReady f =>
f (TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
forall (x :: TK) (z :: TK).
HFun x z -> forall (f :: Target). ADReady f => f x -> f z
unHFun HFun
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
HFun
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind by))
df (f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> f (TKProduct accy ey)
-> f (TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (f (ADTensorKind accy)
-> f (ADTensorKind ey)
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair f (ADTensorKind accy)
dacc (f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
-> f (ADTensorKind ey)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
de_acc_e1))
(f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
-> f (TKProduct accy ey)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
de_acc_e1)))
target (ADTensorKind accy)
cacc0
(target (BuildTensorKind k (ADTensorKind ey))
-> target
(TKProduct (BuildTensorKind k accy) (BuildTensorKind k ey))
-> target
(TKProduct
(BuildTensorKind k (ADTensorKind ey))
(TKProduct (BuildTensorKind k accy) (BuildTensorKind k ey)))
forall (x :: TK) (z :: TK).
target x -> target z -> target (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair target (ADTensorKind (BuildTensorKind k ey))
target (BuildTensorKind k (ADTensorKind ey))
ces (target (BuildTensorKind k accy)
-> target (BuildTensorKind k ey)
-> target
(TKProduct (BuildTensorKind k accy) (BuildTensorKind k ey))
forall (x :: TK) (z :: TK).
target x -> target z -> target (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair target (BuildTensorKind k accy)
q target (BuildTensorKind k ey)
es)))
DeltaMapAccumL SNat k
k FullShapeTK by
bftk FullShapeTK ey
eftk target (BuildTensorKind k accy)
q target (BuildTensorKind k ey)
es HFun
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
df HFun
(TKProduct (ADTensorKind (TKProduct accy by)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy ey))
_rf Delta target accy
acc0' Delta target (BuildTensorKind k ey)
es'
| (:~:)
@TK
(BuildTensorKind k (ADTensorKind by))
(ADTensorKind (BuildTensorKind k by))
Refl <- SNat k
-> SingletonTK by
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind by))
(ADTensorKind (BuildTensorKind k by))
forall (k :: Nat) (y :: TK).
SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
lemBuildOfAD SNat k
k (FullShapeTK by -> SingletonTK by
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK by
bftk)
, (:~:)
@TK
(BuildTensorKind k (ADTensorKind ey))
(ADTensorKind (BuildTensorKind k ey))
Refl <- SNat k
-> SingletonTK ey
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind ey))
(ADTensorKind (BuildTensorKind k ey))
forall (k :: Nat) (y :: TK).
SNat k
-> SingletonTK y
-> (:~:)
@TK
(BuildTensorKind k (ADTensorKind y))
(ADTensorKind (BuildTensorKind k y))
lemBuildOfAD SNat k
k (FullShapeTK ey -> SingletonTK ey
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK ey
eftk) ->
let accftk :: FullShapeTK accy
accftk = Delta target accy -> FullShapeTK accy
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target accy
acc0'
accftkAD :: FullShapeTK (ADTensorKind accy)
accftkAD = FullShapeTK accy -> FullShapeTK (ADTensorKind accy)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK accy
accftk
bftkAD :: FullShapeTK (ADTensorKind by)
bftkAD = FullShapeTK by -> FullShapeTK (ADTensorKind by)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK by
bftk
eftkAD :: FullShapeTK (ADTensorKind ey)
eftkAD = FullShapeTK ey -> FullShapeTK (ADTensorKind ey)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK ey
eftk
(ADMap target
s2, target (ADTensorKind accy)
cacc0) = IMap target
-> ADMap target
-> Delta target accy
-> (ADMap target, target (ADTensorKind accy))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwd IMap target
params ADMap target
s Delta target accy
acc0'
(ADMap target
s3, target (ADTensorKind (BuildTensorKind k ey))
ces) = IMap target
-> ADMap target
-> Delta target (BuildTensorKind k ey)
-> (ADMap target, target (ADTensorKind (BuildTensorKind k ey)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwd IMap target
params ADMap target
s2 Delta target (BuildTensorKind k ey)
es'
in (ADMap target
s3, Proxy @Target target
-> SNat k
-> FullShapeTK (ADTensorKind accy)
-> FullShapeTK (ADTensorKind by)
-> FullShapeTK (TKProduct (ADTensorKind ey) (TKProduct accy ey))
-> (forall (f :: Target).
ADReady f =>
f (ADTensorKind accy)
-> f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> target (ADTensorKind accy)
-> target
(BuildTensorKind
k (TKProduct (ADTensorKind ey) (TKProduct accy ey)))
-> target
(TKProduct
(ADTensorKind accy) (BuildTensorKind k (ADTensorKind by)))
forall (accy :: TK) (by :: TK) (ey :: TK) (k :: Nat)
(target :: Target).
BaseTensor target =>
Proxy @Target target
-> SNat k
-> FullShapeTK accy
-> FullShapeTK by
-> FullShapeTK ey
-> (forall (f :: Target).
ADReady f =>
f accy -> f ey -> f (TKProduct accy by))
-> target accy
-> target (BuildTensorKind k ey)
-> target (TKProduct accy (BuildTensorKind k by))
tmapAccumL (forall {k} (t :: k). Proxy @k t
forall (t :: Target). Proxy @Target t
Proxy @target)
SNat k
k FullShapeTK (ADTensorKind accy)
accftkAD FullShapeTK (ADTensorKind by)
bftkAD (FullShapeTK (ADTensorKind ey)
-> FullShapeTK (TKProduct accy ey)
-> FullShapeTK (TKProduct (ADTensorKind ey) (TKProduct accy ey))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK (ADTensorKind ey)
eftkAD
(FullShapeTK accy
-> FullShapeTK ey -> FullShapeTK (TKProduct accy ey)
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK accy
accftk FullShapeTK ey
eftk))
(\f (ADTensorKind accy)
dacc f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
de_acc_e ->
f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
-> (f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
forall (x :: TK) (z :: TK). f x -> (f x -> f z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
LetTensor target =>
target x -> (target x -> target z) -> target z
ttlet f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
de_acc_e ((f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> (f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by)))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
forall a b. (a -> b) -> a -> b
$ \ !f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
de_acc_e1 ->
HFun
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind by))
-> forall (f :: Target).
ADReady f =>
f (TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
-> f (TKProduct (ADTensorKind accy) (ADTensorKind by))
forall (x :: TK) (z :: TK).
HFun x z -> forall (f :: Target). ADReady f => f x -> f z
unHFun HFun
(TKProduct (ADTensorKind (TKProduct accy ey)) (TKProduct accy ey))
(ADTensorKind (TKProduct accy by))
HFun
(TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
(TKProduct (ADTensorKind accy) (ADTensorKind by))
df (f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
-> f (TKProduct accy ey)
-> f (TKProduct
(TKProduct (ADTensorKind accy) (ADTensorKind ey))
(TKProduct accy ey))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (f (ADTensorKind accy)
-> f (ADTensorKind ey)
-> f (TKProduct (ADTensorKind accy) (ADTensorKind ey))
forall (x :: TK) (z :: TK). f x -> f z -> f (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair f (ADTensorKind accy)
dacc (f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
-> f (ADTensorKind ey)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
de_acc_e1))
(f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
-> f (TKProduct accy ey)
forall (x :: TK) (z :: TK). f (TKProduct x z) -> f z
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target z
tproject2 f (TKProduct (ADTensorKind ey) (TKProduct accy ey))
de_acc_e1)))
target (ADTensorKind accy)
cacc0
(target (BuildTensorKind k (ADTensorKind ey))
-> target
(TKProduct (BuildTensorKind k accy) (BuildTensorKind k ey))
-> target
(TKProduct
(BuildTensorKind k (ADTensorKind ey))
(TKProduct (BuildTensorKind k accy) (BuildTensorKind k ey)))
forall (x :: TK) (z :: TK).
target x -> target z -> target (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair target (ADTensorKind (BuildTensorKind k ey))
target (BuildTensorKind k (ADTensorKind ey))
ces (target (BuildTensorKind k accy)
-> target (BuildTensorKind k ey)
-> target
(TKProduct (BuildTensorKind k accy) (BuildTensorKind k ey))
forall (x :: TK) (z :: TK).
target x -> target z -> target (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair target (BuildTensorKind k accy)
q target (BuildTensorKind k ey)
es)))
Delta target y
_ -> let y :: FullShapeTK y
y = Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
d0
ay :: FullShapeTK (ADTensorKind y)
ay = FullShapeTK y -> FullShapeTK (ADTensorKind y)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK y
y
in case FullShapeTK y
-> FullShapeTK (ADTensorKind y)
-> Maybe ((:~:) @TK y (ADTensorKind y))
forall (y1 :: TK) (y2 :: TK).
FullShapeTK y1 -> FullShapeTK y2 -> Maybe ((:~:) @TK y1 y2)
matchingFTK FullShapeTK y
y FullShapeTK (ADTensorKind y)
ay of
Just (:~:) @TK y (ADTensorKind y)
Refl -> IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target y
d0
Maybe ((:~:) @TK y (ADTensorKind y))
_ -> (ADMap target
s, FullShapeTK (ADTensorKind y) -> target (ADTensorKind y)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget FullShapeTK (ADTensorKind y)
ay)
evalFwdSame
:: forall target y.
(ADReadyNoLet target, ShareTensor target, y ~ ADTensorKind y)
=> IMap target -> ADMap target -> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame :: forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s = \case
DeltaInput InputId target y
inputId ->
case InputId target y -> IMap target -> Maybe (TensorOrZero target y)
forall {kind} (k :: kind -> Type) (a :: kind) (v :: kind -> Type).
(Enum1 @kind k, TestEquality @kind k) =>
k a -> DEnumMap @kind k v -> Maybe (v a)
DMap.lookup InputId target y
inputId IMap target
params of
Just TensorOrZero target y
dtk -> (ADMap target
s, TensorOrZero target y -> target y
forall (target :: Target) (x :: TK).
ADReadyNoLet target =>
TensorOrZero target x -> target x
evalTensorOrZero TensorOrZero target y
dtk)
Maybe (TensorOrZero target y)
Nothing -> String -> (ADMap target, target y)
forall a. (?callStack::CallStack) => String -> a
error String
"evalFwdSame: missing input"
DeltaZero FullShapeTK y
ftk -> (ADMap target
s, FullShapeTK (ADTensorKind y) -> target (ADTensorKind y)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget (FullShapeTK (ADTensorKind y) -> target (ADTensorKind y))
-> FullShapeTK (ADTensorKind y) -> target (ADTensorKind y)
forall a b. (a -> b) -> a -> b
$ FullShapeTK y -> FullShapeTK (ADTensorKind y)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK y
ftk)
DeltaScale (NestedTarget target y
k) Delta target y
d -> (target y -> target (ADTensorKind y))
-> (ADMap target, target y)
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second (target y -> target y -> target y
forall a. Num a => a -> a -> a
* target y
k) ((ADMap target, target y)
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target y)
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target y
d
DeltaAdd Delta target y
d Delta target y
e -> let (ADMap target
s2, target (ADTensorKind y)
t) = IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target y
d
(ADMap target
s3, target (ADTensorKind y)
u) = IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s2 Delta target y
e
in (ADMap target
s3, target y
target (ADTensorKind y)
t target y -> target y -> target y
forall a. Num a => a -> a -> a
+ target y
target (ADTensorKind y)
u)
d0 :: Delta target y
d0@(DeltaCastK @r1 Delta target (TKScalar r1)
d) ->
case SingletonTK (TKScalar r1)
-> SingletonTK (TKScalar (ADTensorScalar r1))
-> Maybe ((:~:) @TK (TKScalar r1) (TKScalar (ADTensorScalar r1)))
forall (y1 :: TK) (y2 :: TK).
SingletonTK y1 -> SingletonTK y2 -> Maybe ((:~:) @TK y1 y2)
sameSTK (forall r. GoodScalar r => SingletonTK (TKScalar r)
STKScalar @r1) (SingletonTK (TKScalar r1)
-> SingletonTK (ADTensorKind (TKScalar r1))
forall (y :: TK). SingletonTK y -> SingletonTK (ADTensorKind y)
adSTK (forall r. GoodScalar r => SingletonTK (TKScalar r)
STKScalar @r1)) of
Just (:~:) @TK (TKScalar r1) (TKScalar (ADTensorScalar r1))
Refl -> (target (TKScalar r1) -> target (ADTensorKind y))
-> (ADMap target, target (TKScalar r1))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second target (TKScalar r1) -> target (ADTensorKind y)
target (TKScalar r1) -> target (TKScalar r2)
forall r1 r2.
(RealFrac r1, GoodScalar r1, RealFrac r2, GoodScalar r2) =>
target (TKScalar r1) -> target (TKScalar r2)
forall (target :: Target) r1 r2.
(BaseTensor target, RealFrac r1, GoodScalar r1, RealFrac r2,
GoodScalar r2) =>
target (TKScalar r1) -> target (TKScalar r2)
tkcast ((ADMap target, target (TKScalar r1))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKScalar r1))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKScalar r1)
-> (ADMap target, target (ADTensorKind (TKScalar r1)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKScalar r1)
d
Maybe ((:~:) @TK (TKScalar r1) (TKScalar (ADTensorScalar r1)))
_ -> (ADMap target
s, FullShapeTK (ADTensorKind y) -> target (ADTensorKind y)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget (FullShapeTK (ADTensorKind y) -> target (ADTensorKind y))
-> FullShapeTK (ADTensorKind y) -> target (ADTensorKind y)
forall a b. (a -> b) -> a -> b
$ FullShapeTK y -> FullShapeTK (ADTensorKind y)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK (FullShapeTK y -> FullShapeTK (ADTensorKind y))
-> FullShapeTK y -> FullShapeTK (ADTensorKind y)
forall a b. (a -> b) -> a -> b
$ Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
d0)
d0 :: Delta target y
d0@(DeltaCastR Delta target (TKR n r1)
d) -> case Delta target (TKR n r1) -> FullShapeTK (TKR n r1)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR n r1)
d of
FullShapeTK (TKR n r1)
y -> case FullShapeTK (TKR n r1)
-> FullShapeTK (TKR2 n (TKScalar (ADTensorScalar r1)))
-> Maybe
((:~:) @TK (TKR n r1) (TKR2 n (TKScalar (ADTensorScalar r1))))
forall (y1 :: TK) (y2 :: TK).
FullShapeTK y1 -> FullShapeTK y2 -> Maybe ((:~:) @TK y1 y2)
matchingFTK FullShapeTK (TKR n r1)
y (FullShapeTK (TKR n r1) -> FullShapeTK (ADTensorKind (TKR n r1))
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK (TKR n r1)
y) of
Just (:~:) @TK (TKR n r1) (TKR2 n (TKScalar (ADTensorScalar r1)))
Refl -> (target (TKR n r1) -> target (ADTensorKind y))
-> (ADMap target, target (TKR n r1))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second target (TKR n r1) -> target (ADTensorKind y)
target (TKR n r1) -> target (TKR2 n (TKScalar r2))
forall r1 r2 (n :: Nat).
(RealFrac r1, GoodScalar r1, RealFrac r2, GoodScalar r2) =>
target (TKR n r1) -> target (TKR n r2)
forall (target :: Target) r1 r2 (n :: Nat).
(BaseTensor target, RealFrac r1, GoodScalar r1, RealFrac r2,
GoodScalar r2) =>
target (TKR n r1) -> target (TKR n r2)
trcast ((ADMap target, target (TKR n r1))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKR n r1))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKR n r1)
-> (ADMap target, target (ADTensorKind (TKR n r1)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKR n r1)
d
Maybe
((:~:) @TK (TKR n r1) (TKR2 n (TKScalar (ADTensorScalar r1))))
_ -> (ADMap target
s, FullShapeTK (ADTensorKind y) -> target (ADTensorKind y)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget (FullShapeTK (ADTensorKind y) -> target (ADTensorKind y))
-> FullShapeTK (ADTensorKind y) -> target (ADTensorKind y)
forall a b. (a -> b) -> a -> b
$ FullShapeTK y -> FullShapeTK (ADTensorKind y)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK (FullShapeTK y -> FullShapeTK (ADTensorKind y))
-> FullShapeTK y -> FullShapeTK (ADTensorKind y)
forall a b. (a -> b) -> a -> b
$ Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
d0)
DeltaSum0R (DeltaZero (FTKR IShR n
_ FullShapeTK x
x)) -> (ADMap target
s, FullShapeTK (TKR2 0 x) -> target (TKR2 0 x)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget (IShR 0 -> FullShapeTK x -> FullShapeTK (TKR2 0 x)
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR IShR 0
forall (n :: Nat) i. ((n :: Nat) ~ (0 :: Nat)) => ShR n i
ZSR FullShapeTK x
x))
DeltaSum0R Delta target (TKR2 n r)
d -> case Delta target (TKR2 n r) -> FullShapeTK (TKR2 n r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 n r)
d of
FTKR IShR n
sh FullShapeTK x
x | SNat n
SNat <- IShR n -> SNat n
forall (n :: Nat) i. ShR n i -> SNat n
shrRank IShR n
sh ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKR2 n r) -> target (ADTensorKind y))
-> (ADMap target, target (TKR2 n r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second target (TKR2 n r) -> target (ADTensorKind y)
target (TKR2 n r) -> target (TKR2 0 r)
forall (n :: Nat) (x :: TK).
(KnownNat n, KnownSTK x) =>
target (TKR2 n x) -> target (TKR2 0 x)
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownNat n, KnownSTK x) =>
target (TKR2 n x) -> target (TKR2 0 x)
trsum0 ((ADMap target, target (TKR2 n r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKR2 n r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKR2 n r)
-> (ADMap target, target (ADTensorKind (TKR2 n r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKR2 n r)
d
DeltaDot0R target (TKR n r)
_ DeltaZero{} -> (ADMap target
s, Ranked 0 r -> target (TKR2 0 (TKScalar r))
forall r (n :: Nat). GoodScalar r => Ranked n r -> target (TKR n r)
forall (target :: Target) r (n :: Nat).
(BaseTensor target, GoodScalar r) =>
Ranked n r -> target (TKR n r)
trconcrete (Ranked 0 r -> target (TKR2 0 (TKScalar r)))
-> Ranked 0 r -> target (TKR2 0 (TKScalar r))
forall a b. (a -> b) -> a -> b
$ r -> Ranked 0 r
forall a. Elt a => a -> Ranked 0 a
Nested.rscalar r
0)
DeltaDot0R target (TKR n r)
v Delta target (TKR n r)
d -> case Delta target (TKR n r) -> FullShapeTK (TKR n r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR n r)
d of
FTKR IShR n
sh FullShapeTK x
x | SNat n
SNat <- IShR n -> SNat n
forall (n :: Nat) i. ShR n i -> SNat n
shrRank IShR n
sh ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKR n r) -> target (ADTensorKind y))
-> (ADMap target, target (TKR n r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second (target (TKR n r)
-> target (TKR n r) -> target (TKR2 0 (TKScalar r))
forall (n :: Nat) r.
(KnownNat n, GoodScalar r) =>
target (TKR n r) -> target (TKR n r) -> target (TKR 0 r)
forall (target :: Target) (n :: Nat) r.
(BaseTensor target, KnownNat n, GoodScalar r) =>
target (TKR n r) -> target (TKR n r) -> target (TKR 0 r)
trdot0 target (TKR n r)
v) ((ADMap target, target (TKR n r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKR n r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKR n r)
-> (ADMap target, target (ADTensorKind (TKR n r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKR n r)
d
DeltaIndexR SNat n
SNat Delta target (TKR2 (m + n) r)
d IxROf target m
ix -> case Delta target (TKR2 (m + n) r) -> FullShapeTK (TKR2 (m + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (m + n) r)
d of
FTKR IShR n
_ FullShapeTK x
x | SNat m
SNat <- IxROf target m -> SNat m
forall (n :: Nat) i. IxR n i -> SNat n
ixrRank IxROf target m
ix ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKR2 (m + n) r) -> target (ADTensorKind y))
-> (ADMap target, target (TKR2 (m + n) r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second (target (TKR2 (m + n) r) -> IxROf target m -> target (TKR2 n r)
forall (m :: Nat) (n :: Nat) (x :: TK).
(KnownNat m, KnownNat n, KnownSTK x) =>
target (TKR2 (m + n) x) -> IxROf target m -> target (TKR2 n x)
forall (target :: Target) (m :: Nat) (n :: Nat) (x :: TK).
(BaseTensor target, KnownNat m, KnownNat n, KnownSTK x) =>
target (TKR2 (m + n) x) -> IxROf target m -> target (TKR2 n x)
`trindex` IxROf target m
ix) ((ADMap target, target (TKR2 (m + n) r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKR2 (m + n) r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKR2 n r)
-> (ADMap target, target (ADTensorKind (TKR2 n r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKR2 n r)
Delta target (TKR2 (m + n) r)
d
DeltaScatterR SNat m
SNat SNat n
SNat SNat p
SNat IShR (p + n)
sh Delta target (TKR2 (m + n) r)
d IxROf target m -> IxROf target p
f -> case Delta target (TKR2 (m + n) r) -> FullShapeTK (TKR2 (m + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (m + n) r)
d of
FTKR IShR n
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
let (ADMap target
s2, target (ADTensorKind (TKR2 n r))
t) = IMap target
-> ADMap target
-> Delta target (TKR2 n r)
-> (ADMap target, target (ADTensorKind (TKR2 n r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKR2 n r)
Delta target (TKR2 (m + n) r)
d
in (ADMap target
s2, IShR (p + n)
-> target (TKR2 (m + n) r)
-> (IxROf target m -> IxROf target p)
-> target (TKR2 (p + n) r)
forall (m :: Nat) (n :: Nat) (p :: Nat) (x :: TK).
(KnownNat m, KnownNat n, KnownNat p, KnownSTK x) =>
IShR (p + n)
-> target (TKR2 (m + n) x)
-> (IxROf target m -> IxROf target p)
-> target (TKR2 (p + n) x)
forall (target :: Target) (m :: Nat) (n :: Nat) (p :: Nat)
(x :: TK).
(BaseTensor target, KnownNat m, KnownNat n, KnownNat p,
KnownSTK x) =>
IShR (p + n)
-> target (TKR2 (m + n) x)
-> (IxROf target m -> IxROf target p)
-> target (TKR2 (p + n) x)
trscatter IShR (p + n)
sh target (ADTensorKind (TKR2 n r))
target (TKR2 (m + n) r)
t IxROf target m -> IxROf target p
f)
DeltaGatherR SNat m
SNat SNat n
SNat SNat p
SNat IShR (m + n)
sh Delta target (TKR2 (p + n) r)
d IxROf target m -> IxROf target p
f -> case Delta target (TKR2 (p + n) r) -> FullShapeTK (TKR2 (p + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (p + n) r)
d of
FTKR IShR n
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
let (ADMap target
s2, target (ADTensorKind (TKR2 n r))
t) = IMap target
-> ADMap target
-> Delta target (TKR2 n r)
-> (ADMap target, target (ADTensorKind (TKR2 n r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKR2 n r)
Delta target (TKR2 (p + n) r)
d
in (ADMap target
s2, IShR (m + n)
-> target (TKR2 (p + n) r)
-> (IxROf target m -> IxROf target p)
-> target (TKR2 (m + n) r)
forall (m :: Nat) (n :: Nat) (p :: Nat) (x :: TK).
(KnownNat m, KnownNat n, KnownNat p, KnownSTK x) =>
IShR (m + n)
-> target (TKR2 (p + n) x)
-> (IxROf target m -> IxROf target p)
-> target (TKR2 (m + n) x)
forall (target :: Target) (m :: Nat) (n :: Nat) (p :: Nat)
(x :: TK).
(BaseTensor target, KnownNat m, KnownNat n, KnownNat p,
KnownSTK x) =>
IShR (m + n)
-> target (TKR2 (p + n) x)
-> (IxROf target m -> IxROf target p)
-> target (TKR2 (m + n) x)
trgather IShR (m + n)
sh target (ADTensorKind (TKR2 n r))
target (TKR2 (p + n) r)
t IxROf target m -> IxROf target p
f)
DeltaAppendR Delta target (TKR2 (1 + n) r)
d Delta target (TKR2 (1 + n) r)
e -> case Delta target (TKR2 (1 + n) r) -> FullShapeTK (TKR2 (1 + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (1 + n) r)
d of
FTKR IShR n
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
let (ADMap target
s2, target (ADTensorKind (TKR2 n r))
t) = IMap target
-> ADMap target
-> Delta target (TKR2 n r)
-> (ADMap target, target (ADTensorKind (TKR2 n r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKR2 n r)
Delta target (TKR2 (1 + n) r)
d
(ADMap target
s3, target (ADTensorKind (TKR2 n r))
u) = IMap target
-> ADMap target
-> Delta target (TKR2 n r)
-> (ADMap target, target (ADTensorKind (TKR2 n r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s2 Delta target (TKR2 n r)
Delta target (TKR2 (1 + n) r)
e
in (ADMap target
s3, target (TKR2 (1 + n) r)
-> target (TKR2 (1 + n) r) -> target (TKR2 (1 + n) r)
forall (n :: Nat) (x :: TK).
KnownSTK x =>
target (TKR2 (1 + n) x)
-> target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKR2 (1 + n) x)
-> target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
trappend target (ADTensorKind (TKR2 n r))
target (TKR2 (1 + n) r)
t target (ADTensorKind (TKR2 n r))
target (TKR2 (1 + n) r)
u)
DeltaSliceR Int
i Int
n Delta target (TKR2 (1 + n) r)
d -> case Delta target (TKR2 (1 + n) r) -> FullShapeTK (TKR2 (1 + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (1 + n) r)
d of
FTKR IShR n
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKR2 (1 + n) r) -> target (ADTensorKind y))
-> (ADMap target, target (TKR2 (1 + n) r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second (Int -> Int -> target (TKR2 (1 + n) r) -> target (TKR2 (1 + n) r)
forall (n :: Nat) (x :: TK).
KnownSTK x =>
Int -> Int -> target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownSTK x) =>
Int -> Int -> target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
trslice Int
i Int
n) ((ADMap target, target (TKR2 (1 + n) r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKR2 (1 + n) r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKR2 n r)
-> (ADMap target, target (ADTensorKind (TKR2 n r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKR2 n r)
Delta target (TKR2 (1 + n) r)
d
DeltaReverseR Delta target (TKR2 (1 + n) r)
d -> case Delta target (TKR2 (1 + n) r) -> FullShapeTK (TKR2 (1 + n) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 (1 + n) r)
d of
FTKR IShR n
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKR2 (1 + n) r) -> target (ADTensorKind y))
-> (ADMap target, target (TKR2 (1 + n) r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second target (TKR2 (1 + n) r) -> target (ADTensorKind y)
target (TKR2 (1 + n) r) -> target (TKR2 (1 + n) r)
forall (n :: Nat) (x :: TK).
KnownSTK x =>
target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x)
trreverse ((ADMap target, target (TKR2 (1 + n) r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKR2 (1 + n) r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKR2 n r)
-> (ADMap target, target (ADTensorKind (TKR2 n r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKR2 n r)
Delta target (TKR2 (1 + n) r)
d
DeltaTransposeR PermR
perm Delta target (TKR2 n r)
d -> case Delta target (TKR2 n r) -> FullShapeTK (TKR2 n r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 n r)
d of
FTKR IShR n
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKR2 n r) -> target (ADTensorKind y))
-> (ADMap target, target (TKR2 n r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second (PermR -> target (TKR2 n r) -> target (TKR2 n r)
forall (n :: Nat) (x :: TK).
KnownSTK x =>
PermR -> target (TKR2 n x) -> target (TKR2 n x)
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownSTK x) =>
PermR -> target (TKR2 n x) -> target (TKR2 n x)
trtranspose PermR
perm) ((ADMap target, target (TKR2 n r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKR2 n r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKR2 n r)
-> (ADMap target, target (ADTensorKind (TKR2 n r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKR2 n r)
d
DeltaReshapeR IShR m
sh2 Delta target (TKR2 n r)
d -> case Delta target (TKR2 n r) -> FullShapeTK (TKR2 n r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKR2 n r)
d of
FTKR IShR n
_sh FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKR2 n r) -> target (ADTensorKind y))
-> (ADMap target, target (TKR2 n r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second (IShR m -> target (TKR2 n r) -> target (TKR2 m r)
forall (n :: Nat) (m :: Nat) (x :: TK).
KnownSTK x =>
IShR m -> target (TKR2 n x) -> target (TKR2 m x)
forall (target :: Target) (n :: Nat) (m :: Nat) (x :: TK).
(BaseTensor target, KnownSTK x) =>
IShR m -> target (TKR2 n x) -> target (TKR2 m x)
trreshape IShR m
sh2) ((ADMap target, target (TKR2 n r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKR2 n r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKR2 n r)
-> (ADMap target, target (ADTensorKind (TKR2 n r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKR2 n r)
d
d0 :: Delta target y
d0@(DeltaCastS Delta target (TKS sh r1)
d) -> case Delta target (TKS sh r1) -> FullShapeTK (TKS sh r1)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS sh r1)
d of
FullShapeTK (TKS sh r1)
y -> case FullShapeTK (TKS sh r1)
-> FullShapeTK (TKS2 sh (TKScalar (ADTensorScalar r1)))
-> Maybe
((:~:) @TK (TKS sh r1) (TKS2 sh (TKScalar (ADTensorScalar r1))))
forall (y1 :: TK) (y2 :: TK).
FullShapeTK y1 -> FullShapeTK y2 -> Maybe ((:~:) @TK y1 y2)
matchingFTK FullShapeTK (TKS sh r1)
y (FullShapeTK (TKS sh r1) -> FullShapeTK (ADTensorKind (TKS sh r1))
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK (TKS sh r1)
y) of
Just (:~:) @TK (TKS sh r1) (TKS2 sh (TKScalar (ADTensorScalar r1)))
Refl -> (target (TKS sh r1) -> target (ADTensorKind y))
-> (ADMap target, target (TKS sh r1))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second target (TKS sh r1) -> target (ADTensorKind y)
target (TKS sh r1) -> target (TKS2 sh (TKScalar r2))
forall r1 r2 (sh :: [Nat]).
(RealFrac r1, GoodScalar r1, RealFrac r2, GoodScalar r2) =>
target (TKS sh r1) -> target (TKS sh r2)
forall (target :: Target) r1 r2 (sh :: [Nat]).
(BaseTensor target, RealFrac r1, GoodScalar r1, RealFrac r2,
GoodScalar r2) =>
target (TKS sh r1) -> target (TKS sh r2)
tscast ((ADMap target, target (TKS sh r1))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKS sh r1))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKS sh r1)
-> (ADMap target, target (ADTensorKind (TKS sh r1)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKS sh r1)
d
Maybe
((:~:) @TK (TKS sh r1) (TKS2 sh (TKScalar (ADTensorScalar r1))))
_ -> (ADMap target
s, FullShapeTK (ADTensorKind y) -> target (ADTensorKind y)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget (FullShapeTK (ADTensorKind y) -> target (ADTensorKind y))
-> FullShapeTK (ADTensorKind y) -> target (ADTensorKind y)
forall a b. (a -> b) -> a -> b
$ FullShapeTK y -> FullShapeTK (ADTensorKind y)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK (FullShapeTK y -> FullShapeTK (ADTensorKind y))
-> FullShapeTK y -> FullShapeTK (ADTensorKind y)
forall a b. (a -> b) -> a -> b
$ Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
d0)
DeltaSum0S (DeltaZero (FTKS ShS sh
_ FullShapeTK x
x)) -> (ADMap target
s, FullShapeTK (TKS2 ('[] @Nat) x) -> target (TKS2 ('[] @Nat) x)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget (ShS ('[] @Nat) -> FullShapeTK x -> FullShapeTK (TKS2 ('[] @Nat) x)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS ShS ('[] @Nat)
forall (sh :: [Nat]).
((sh :: [Nat]) ~ ('[] @Nat :: [Nat])) =>
ShS sh
ZSS FullShapeTK x
x))
DeltaSum0S Delta target (TKS2 sh r)
d -> case Delta target (TKS2 sh r) -> FullShapeTK (TKS2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 sh r)
d of
FTKS ShS sh
sh FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
ShS sh
-> (KnownShS sh => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS sh
sh ((KnownShS sh => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShS sh => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKS2 sh r) -> target (ADTensorKind y))
-> (ADMap target, target (TKS2 sh r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second target (TKS2 sh r) -> target (ADTensorKind y)
target (TKS2 sh r) -> target (TKS2 ('[] @Nat) r)
forall (sh :: [Nat]) (x :: TK).
(KnownShS sh, KnownSTK x) =>
target (TKS2 sh x) -> target (TKS2 ('[] @Nat) x)
forall (target :: Target) (sh :: [Nat]) (x :: TK).
(BaseTensor target, KnownShS sh, KnownSTK x) =>
target (TKS2 sh x) -> target (TKS2 ('[] @Nat) x)
tssum0 ((ADMap target, target (TKS2 sh r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKS2 sh r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKS2 sh r)
-> (ADMap target, target (ADTensorKind (TKS2 sh r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKS2 sh r)
d
DeltaDot0S target (TKS sh r)
_ DeltaZero{} -> (ADMap target
s, Shaped ('[] @Nat) r -> target (TKS2 ('[] @Nat) (TKScalar r))
forall r (sh :: [Nat]).
GoodScalar r =>
Shaped sh r -> target (TKS sh r)
forall (target :: Target) r (sh :: [Nat]).
(BaseTensor target, GoodScalar r) =>
Shaped sh r -> target (TKS sh r)
tsconcrete (Shaped ('[] @Nat) r -> target (TKS2 ('[] @Nat) (TKScalar r)))
-> Shaped ('[] @Nat) r -> target (TKS2 ('[] @Nat) (TKScalar r))
forall a b. (a -> b) -> a -> b
$ r -> Shaped ('[] @Nat) r
forall a. Elt a => a -> Shaped ('[] @Nat) a
Nested.sscalar r
0)
DeltaDot0S target (TKS sh r)
v Delta target (TKS sh r)
d -> case Delta target (TKS sh r) -> FullShapeTK (TKS sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS sh r)
d of
FTKS ShS sh
sh FullShapeTK x
FTKScalar ->
ShS sh
-> (KnownShS sh => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS sh
sh ((KnownShS sh => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShS sh => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKS sh r) -> target (ADTensorKind y))
-> (ADMap target, target (TKS sh r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second (target (TKS sh r)
-> target (TKS sh r) -> target (TKS2 ('[] @Nat) (TKScalar r))
forall (sh :: [Nat]) r.
(KnownShS sh, GoodScalar r) =>
target (TKS sh r) -> target (TKS sh r) -> target (TKS ('[] @Nat) r)
forall (target :: Target) (sh :: [Nat]) r.
(BaseTensor target, KnownShS sh, GoodScalar r) =>
target (TKS sh r) -> target (TKS sh r) -> target (TKS ('[] @Nat) r)
tsdot0 target (TKS sh r)
v) ((ADMap target, target (TKS sh r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKS sh r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKS sh r)
-> (ADMap target, target (ADTensorKind (TKS sh r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKS sh r)
d
DeltaIndexS ShS shn
shn Delta target (TKS2 ((++) @Nat shm shn) r)
d IxSOf target shm
ix -> case Delta target (TKS2 ((++) @Nat shm shn) r)
-> FullShapeTK (TKS2 ((++) @Nat shm shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((++) @Nat shm shn) r)
d of
FTKS ShS sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
ShS shn
-> (KnownShS shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shn
shn ((KnownShS shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShS shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
ShS shm
-> (KnownShS shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS (IxSOf target shm -> ShS shm
forall (sh :: [Nat]) i. IxS sh i -> ShS sh
shsFromIxS IxSOf target shm
ix) ((KnownShS shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShS shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKS2 ((++) @Nat shm shn) r) -> target (ADTensorKind y))
-> (ADMap target, target (TKS2 ((++) @Nat shm shn) r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second (target (TKS2 ((++) @Nat shm shn) r)
-> IxSOf target shm -> target (TKS2 shn r)
forall (shm :: [Nat]) (shn :: [Nat]) (x :: TK).
(KnownShS shm, KnownShS shn, KnownSTK x) =>
target (TKS2 ((++) @Nat shm shn) x)
-> IxSOf target shm -> target (TKS2 shn x)
forall (target :: Target) (shm :: [Nat]) (shn :: [Nat]) (x :: TK).
(BaseTensor target, KnownShS shm, KnownShS shn, KnownSTK x) =>
target (TKS2 ((++) @Nat shm shn) x)
-> IxSOf target shm -> target (TKS2 shn x)
`tsindex` IxSOf target shm
ix) ((ADMap target, target (TKS2 ((++) @Nat shm shn) r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKS2 ((++) @Nat shm shn) r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKS2 sh r)
-> (ADMap target, target (ADTensorKind (TKS2 sh r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKS2 sh r)
Delta target (TKS2 ((++) @Nat shm shn) r)
d
DeltaScatterS @shm @shn ShS shm
shm ShS shn
shn ShS shp
shp Delta target (TKS2 ((++) @Nat shm shn) r)
d IxSOf target shm -> IxSOf target shp
f -> case Delta target (TKS2 ((++) @Nat shm shn) r)
-> FullShapeTK (TKS2 ((++) @Nat shm shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((++) @Nat shm shn) r)
d of
FTKS ShS sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
ShS shm
-> (KnownShS shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shm
shm ((KnownShS shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShS shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
ShS shn
-> (KnownShS shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shn
shn ((KnownShS shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShS shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
ShS shp
-> (KnownShS shp => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shp
shp ((KnownShS shp => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShS shp => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
let (ADMap target
s2, target (ADTensorKind (TKS2 sh r))
t) = IMap target
-> ADMap target
-> Delta target (TKS2 sh r)
-> (ADMap target, target (ADTensorKind (TKS2 sh r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKS2 sh r)
Delta target (TKS2 ((++) @Nat shm shn) r)
d
in (ADMap target
s2, forall (target :: Target) (shm :: [Nat]) (shn :: [Nat])
(shp :: [Nat]) (x :: TK).
(BaseTensor target, KnownShS shm, KnownShS shn, KnownShS shp,
KnownSTK x) =>
target (TKS2 ((++) @Nat shm shn) x)
-> (IxSOf target shm -> IxSOf target shp)
-> target (TKS2 ((++) @Nat shp shn) x)
tsscatter @_ @shm @shn target (ADTensorKind (TKS2 sh r))
target (TKS2 ((++) @Nat shm shn) r)
t IxSOf target shm -> IxSOf target shp
f)
DeltaGatherS @shm @shn ShS shm
shm ShS shn
shn ShS shp
shp Delta target (TKS2 ((++) @Nat shp shn) r)
d IxSOf target shm -> IxSOf target shp
f -> case Delta target (TKS2 ((++) @Nat shp shn) r)
-> FullShapeTK (TKS2 ((++) @Nat shp shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((++) @Nat shp shn) r)
d of
FTKS ShS sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
ShS shm
-> (KnownShS shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shm
shm ((KnownShS shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShS shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
ShS shn
-> (KnownShS shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shn
shn ((KnownShS shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShS shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
ShS shp
-> (KnownShS shp => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Nat]) r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS ShS shp
shp ((KnownShS shp => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShS shp => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
let (ADMap target
s2, target (ADTensorKind (TKS2 sh r))
t) = IMap target
-> ADMap target
-> Delta target (TKS2 sh r)
-> (ADMap target, target (ADTensorKind (TKS2 sh r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKS2 sh r)
Delta target (TKS2 ((++) @Nat shp shn) r)
d
in (ADMap target
s2, forall (target :: Target) (shm :: [Nat]) (shn :: [Nat])
(shp :: [Nat]) (x :: TK).
(BaseTensor target, KnownShS shm, KnownShS shn, KnownShS shp,
KnownSTK x) =>
target (TKS2 ((++) @Nat shp shn) x)
-> (IxSOf target shm -> IxSOf target shp)
-> target (TKS2 ((++) @Nat shm shn) x)
tsgather @_ @shm @shn target (ADTensorKind (TKS2 sh r))
target (TKS2 ((++) @Nat shp shn) r)
t IxSOf target shm -> IxSOf target shp
f)
DeltaAppendS Delta target (TKS2 ((':) @Nat m sh) r)
d Delta target (TKS2 ((':) @Nat n sh) r)
e -> case Delta target (TKS2 ((':) @Nat m sh) r)
-> FullShapeTK (TKS2 ((':) @Nat m sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((':) @Nat m sh) r)
d of
FTKS ShS sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
let (ADMap target
s2, target (ADTensorKind (TKS2 ((':) @Nat m sh) r))
t) = IMap target
-> ADMap target
-> Delta target (TKS2 ((':) @Nat m sh) r)
-> (ADMap target, target (ADTensorKind (TKS2 ((':) @Nat m sh) r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKS2 ((':) @Nat m sh) r)
d
(ADMap target
s3, target (ADTensorKind (TKS2 ((':) @Nat n sh) r))
u) = IMap target
-> ADMap target
-> Delta target (TKS2 ((':) @Nat n sh) r)
-> (ADMap target, target (ADTensorKind (TKS2 ((':) @Nat n sh) r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s2 Delta target (TKS2 ((':) @Nat n sh) r)
e
in (ADMap target
s3, target (TKS2 ((':) @Nat m sh) r)
-> target (TKS2 ((':) @Nat n sh) r)
-> target (TKS2 ((':) @Nat (m + n) sh) r)
forall (m :: Nat) (n :: Nat) (sh :: [Nat]) (x :: TK).
KnownSTK x =>
target (TKS2 ((':) @Nat m sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat (m + n) sh) x)
forall (target :: Target) (m :: Nat) (n :: Nat) (sh :: [Nat])
(x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKS2 ((':) @Nat m sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat (m + n) sh) x)
tsappend target (ADTensorKind (TKS2 ((':) @Nat m sh) r))
target (TKS2 ((':) @Nat m sh) r)
t target (ADTensorKind (TKS2 ((':) @Nat n sh) r))
target (TKS2 ((':) @Nat n sh) r)
u)
DeltaSliceS SNat i
i SNat n
n SNat k
k Delta target (TKS2 ((':) @Nat ((i + n) + k) sh) r)
d -> case Delta target (TKS2 ((':) @Nat ((i + n) + k) sh) r)
-> FullShapeTK (TKS2 ((':) @Nat ((i + n) + k) sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((':) @Nat ((i + n) + k) sh) r)
d of
FTKS ShS sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKS2 ((':) @Nat ((i + n) + k) sh) r)
-> target (ADTensorKind y))
-> (ADMap target, target (TKS2 ((':) @Nat ((i + n) + k) sh) r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second (SNat i
-> SNat n
-> SNat k
-> target (TKS2 ((':) @Nat ((i + n) + k) sh) r)
-> target (TKS2 ((':) @Nat n sh) r)
forall (i :: Nat) (n :: Nat) (k :: Nat) (sh :: [Nat]) (x :: TK).
KnownSTK x =>
SNat i
-> SNat n
-> SNat k
-> target (TKS2 ((':) @Nat ((i + n) + k) sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
forall (target :: Target) (i :: Nat) (n :: Nat) (k :: Nat)
(sh :: [Nat]) (x :: TK).
(BaseTensor target, KnownSTK x) =>
SNat i
-> SNat n
-> SNat k
-> target (TKS2 ((':) @Nat ((i + n) + k) sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
tsslice SNat i
i SNat n
n SNat k
k) ((ADMap target, target (TKS2 ((':) @Nat ((i + n) + k) sh) r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKS2 ((':) @Nat ((i + n) + k) sh) r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKS2 ((':) @Nat ((i + n) + k) sh) r)
-> (ADMap target,
target (ADTensorKind (TKS2 ((':) @Nat ((i + n) + k) sh) r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKS2 ((':) @Nat ((i + n) + k) sh) r)
d
DeltaReverseS Delta target (TKS2 ((':) @Nat n sh) r)
d -> case Delta target (TKS2 ((':) @Nat n sh) r)
-> FullShapeTK (TKS2 ((':) @Nat n sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 ((':) @Nat n sh) r)
d of
FTKS ShS sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKS2 ((':) @Nat n sh) r) -> target (ADTensorKind y))
-> (ADMap target, target (TKS2 ((':) @Nat n sh) r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second target (TKS2 ((':) @Nat n sh) r) -> target (ADTensorKind y)
target (TKS2 ((':) @Nat n sh) r)
-> target (TKS2 ((':) @Nat n sh) r)
forall (n :: Nat) (sh :: [Nat]) (x :: TK).
KnownSTK x =>
target (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
forall (target :: Target) (n :: Nat) (sh :: [Nat]) (x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKS2 ((':) @Nat n sh) x)
-> target (TKS2 ((':) @Nat n sh) x)
tsreverse ((ADMap target, target (TKS2 ((':) @Nat n sh) r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKS2 ((':) @Nat n sh) r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKS2 ((':) @Nat n sh) r)
-> (ADMap target, target (ADTensorKind (TKS2 ((':) @Nat n sh) r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKS2 ((':) @Nat n sh) r)
d
DeltaTransposeS Perm perm
perm Delta target (TKS2 sh r)
d -> case Delta target (TKS2 sh r) -> FullShapeTK (TKS2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 sh r)
d of
FTKS ShS sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKS2 sh r) -> target (ADTensorKind y))
-> (ADMap target, target (TKS2 sh r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second (Perm perm
-> target (TKS2 sh r)
-> target (TKS2 (PermutePrefix @Nat perm sh) r)
forall (perm :: [Nat]) (sh :: [Nat]) (x :: TK).
(IsPermutation perm, (<=) @Nat (Rank @Nat perm) (Rank @Nat sh),
KnownSTK x) =>
Perm perm
-> target (TKS2 sh x)
-> target (TKS2 (PermutePrefix @Nat perm sh) x)
forall (target :: Target) (perm :: [Nat]) (sh :: [Nat]) (x :: TK).
(BaseTensor target, IsPermutation perm,
(<=) @Nat (Rank @Nat perm) (Rank @Nat sh), KnownSTK x) =>
Perm perm
-> target (TKS2 sh x)
-> target (TKS2 (PermutePrefix @Nat perm sh) x)
tstranspose Perm perm
perm) ((ADMap target, target (TKS2 sh r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKS2 sh r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKS2 sh r)
-> (ADMap target, target (ADTensorKind (TKS2 sh r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKS2 sh r)
d
DeltaReshapeS ShS sh2
sh2 Delta target (TKS2 sh r)
d -> case Delta target (TKS2 sh r) -> FullShapeTK (TKS2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKS2 sh r)
d of
FTKS ShS sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKS2 sh r) -> target (ADTensorKind y))
-> (ADMap target, target (TKS2 sh r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second (ShS sh2 -> target (TKS2 sh r) -> target (TKS2 sh2 r)
forall (sh :: [Nat]) (sh2 :: [Nat]) (x :: TK).
((Product sh :: Nat) ~ (Product sh2 :: Nat), KnownSTK x) =>
ShS sh2 -> target (TKS2 sh x) -> target (TKS2 sh2 x)
forall (target :: Target) (sh :: [Nat]) (sh2 :: [Nat]) (x :: TK).
(BaseTensor target, (Product sh :: Nat) ~ (Product sh2 :: Nat),
KnownSTK x) =>
ShS sh2 -> target (TKS2 sh x) -> target (TKS2 sh2 x)
tsreshape ShS sh2
sh2) ((ADMap target, target (TKS2 sh r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKS2 sh r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKS2 sh r)
-> (ADMap target, target (ADTensorKind (TKS2 sh r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKS2 sh r)
d
d0 :: Delta target y
d0@(DeltaCastX Delta target (TKX sh r1)
d) -> case Delta target (TKX sh r1) -> FullShapeTK (TKX sh r1)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX sh r1)
d of
FullShapeTK (TKX sh r1)
y -> case FullShapeTK (TKX sh r1)
-> FullShapeTK (TKX2 sh (TKScalar (ADTensorScalar r1)))
-> Maybe
((:~:) @TK (TKX sh r1) (TKX2 sh (TKScalar (ADTensorScalar r1))))
forall (y1 :: TK) (y2 :: TK).
FullShapeTK y1 -> FullShapeTK y2 -> Maybe ((:~:) @TK y1 y2)
matchingFTK FullShapeTK (TKX sh r1)
y (FullShapeTK (TKX sh r1) -> FullShapeTK (ADTensorKind (TKX sh r1))
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK FullShapeTK (TKX sh r1)
y) of
Just (:~:) @TK (TKX sh r1) (TKX2 sh (TKScalar (ADTensorScalar r1)))
Refl -> (target (TKX sh r1) -> target (ADTensorKind y))
-> (ADMap target, target (TKX sh r1))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second target (TKX sh r1) -> target (ADTensorKind y)
target (TKX sh r1) -> target (TKX2 sh (TKScalar r2))
forall r1 r2 (sh :: [Maybe Nat]).
(RealFrac r1, GoodScalar r1, RealFrac r2, GoodScalar r2) =>
target (TKX sh r1) -> target (TKX sh r2)
forall (target :: Target) r1 r2 (sh :: [Maybe Nat]).
(BaseTensor target, RealFrac r1, GoodScalar r1, RealFrac r2,
GoodScalar r2) =>
target (TKX sh r1) -> target (TKX sh r2)
txcast ((ADMap target, target (TKX sh r1))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKX sh r1))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKX sh r1)
-> (ADMap target, target (ADTensorKind (TKX sh r1)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKX sh r1)
d
Maybe
((:~:) @TK (TKX sh r1) (TKX2 sh (TKScalar (ADTensorScalar r1))))
_ -> (ADMap target
s, FullShapeTK (ADTensorKind y) -> target (ADTensorKind y)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget (FullShapeTK (ADTensorKind y) -> target (ADTensorKind y))
-> FullShapeTK (ADTensorKind y) -> target (ADTensorKind y)
forall a b. (a -> b) -> a -> b
$ FullShapeTK y -> FullShapeTK (ADTensorKind y)
forall (y :: TK). FullShapeTK y -> FullShapeTK (ADTensorKind y)
adFTK (FullShapeTK y -> FullShapeTK (ADTensorKind y))
-> FullShapeTK y -> FullShapeTK (ADTensorKind y)
forall a b. (a -> b) -> a -> b
$ Delta target y -> FullShapeTK y
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target y
d0)
DeltaSum0X (DeltaZero (FTKX IShX sh
_ FullShapeTK x
x)) -> (ADMap target
s, FullShapeTK (TKX2 ('[] @(Maybe Nat)) x)
-> target (TKX2 ('[] @(Maybe Nat)) x)
forall (y :: TK). FullShapeTK y -> target y
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y
tdefTarget (IShX ('[] @(Maybe Nat))
-> FullShapeTK x -> FullShapeTK (TKX2 ('[] @(Maybe Nat)) x)
forall (sh :: [Maybe Nat]) (x :: TK).
IShX sh -> FullShapeTK x -> FullShapeTK (TKX2 sh x)
FTKX IShX ('[] @(Maybe Nat))
forall (sh :: [Maybe Nat]) i.
((sh :: [Maybe Nat]) ~ ('[] @(Maybe Nat) :: [Maybe Nat])) =>
ShX sh i
ZSX FullShapeTK x
x))
DeltaSum0X Delta target (TKX2 sh r)
d -> case Delta target (TKX2 sh r) -> FullShapeTK (TKX2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 sh r)
d of
FTKX IShX sh
sh FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
StaticShX sh
-> (KnownShX sh => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX (IShX sh -> StaticShX sh
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX IShX sh
sh) ((KnownShX sh => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShX sh => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKX2 sh r) -> target (ADTensorKind y))
-> (ADMap target, target (TKX2 sh r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second target (TKX2 sh r) -> target (ADTensorKind y)
target (TKX2 sh r) -> target (TKX2 ('[] @(Maybe Nat)) r)
forall (sh :: [Maybe Nat]) (x :: TK).
(KnownShX sh, KnownSTK x, ConvertTensor target) =>
target (TKX2 sh x) -> target (TKX2 ('[] @(Maybe Nat)) x)
forall (target :: Target) (sh :: [Maybe Nat]) (x :: TK).
(BaseTensor target, KnownShX sh, KnownSTK x,
ConvertTensor target) =>
target (TKX2 sh x) -> target (TKX2 ('[] @(Maybe Nat)) x)
txsum0 ((ADMap target, target (TKX2 sh r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKX2 sh r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKX2 sh r)
-> (ADMap target, target (ADTensorKind (TKX2 sh r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKX2 sh r)
d
DeltaDot0X target (TKX sh r)
_ DeltaZero{} -> (ADMap target
s, Mixed ('[] @(Maybe Nat)) r
-> target (TKX2 ('[] @(Maybe Nat)) (TKScalar r))
forall r (sh :: [Maybe Nat]).
GoodScalar r =>
Mixed sh r -> target (TKX sh r)
forall (target :: Target) r (sh :: [Maybe Nat]).
(BaseTensor target, GoodScalar r) =>
Mixed sh r -> target (TKX sh r)
txconcrete (Mixed ('[] @(Maybe Nat)) r
-> target (TKX2 ('[] @(Maybe Nat)) (TKScalar r)))
-> Mixed ('[] @(Maybe Nat)) r
-> target (TKX2 ('[] @(Maybe Nat)) (TKScalar r))
forall a b. (a -> b) -> a -> b
$ r -> Mixed ('[] @(Maybe Nat)) r
forall a. Elt a => a -> Mixed ('[] @(Maybe Nat)) a
Nested.mscalar r
0)
DeltaDot0X target (TKX sh r)
v Delta target (TKX sh r)
d -> case Delta target (TKX sh r) -> FullShapeTK (TKX sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX sh r)
d of
FTKX IShX sh
sh FullShapeTK x
FTKScalar ->
StaticShX sh
-> (KnownShX sh => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX (IShX sh -> StaticShX sh
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX IShX sh
sh) ((KnownShX sh => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShX sh => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKX sh r) -> target (ADTensorKind y))
-> (ADMap target, target (TKX sh r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second (target (TKX sh r)
-> target (TKX sh r)
-> target (TKX2 ('[] @(Maybe Nat)) (TKScalar r))
forall (sh :: [Maybe Nat]) r.
(KnownShX sh, GoodScalar r, ConvertTensor target) =>
target (TKX sh r)
-> target (TKX sh r) -> target (TKX ('[] @(Maybe Nat)) r)
forall (target :: Target) (sh :: [Maybe Nat]) r.
(BaseTensor target, KnownShX sh, GoodScalar r,
ConvertTensor target) =>
target (TKX sh r)
-> target (TKX sh r) -> target (TKX ('[] @(Maybe Nat)) r)
txdot0 target (TKX sh r)
v) ((ADMap target, target (TKX sh r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKX sh r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKX sh r)
-> (ADMap target, target (ADTensorKind (TKX sh r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKX sh r)
d
DeltaIndexX @shm @shn StaticShX shn
shn Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d IxXOf target shm
ix -> case Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
-> FullShapeTK (TKX2 ((++) @(Maybe Nat) shm shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d of
FTKX IShX sh
sh FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
StaticShX shn
-> (KnownShX shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX shn
shn ((KnownShX shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShX shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
StaticShX shm
-> (KnownShX shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX (forall (sh :: [Maybe Nat]) (sh' :: [Maybe Nat]) i.
StaticShX ((++) @(Maybe Nat) sh sh') -> IxX sh i -> StaticShX sh
ssxTakeIx @shm @shn (IShX sh -> StaticShX sh
forall (sh :: [Maybe Nat]) i. ShX sh i -> StaticShX sh
ssxFromShX IShX sh
sh) IxXOf target shm
ix) ((KnownShX shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShX shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKX2 ((++) @(Maybe Nat) shm shn) r)
-> target (ADTensorKind y))
-> (ADMap target, target (TKX2 ((++) @(Maybe Nat) shm shn) r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second (target (TKX2 ((++) @(Maybe Nat) shm shn) r)
-> IxXOf target shm -> target (TKX2 shn r)
forall (sh1 :: [Maybe Nat]) (sh2 :: [Maybe Nat]) (x :: TK).
(KnownShX sh1, KnownShX sh2, KnownSTK x) =>
target (TKX2 ((++) @(Maybe Nat) sh1 sh2) x)
-> IxXOf target sh1 -> target (TKX2 sh2 x)
forall (target :: Target) (sh1 :: [Maybe Nat]) (sh2 :: [Maybe Nat])
(x :: TK).
(BaseTensor target, KnownShX sh1, KnownShX sh2, KnownSTK x) =>
target (TKX2 ((++) @(Maybe Nat) sh1 sh2) x)
-> IxXOf target sh1 -> target (TKX2 sh2 x)
`txindex` IxXOf target shm
ix) ((ADMap target, target (TKX2 ((++) @(Maybe Nat) shm shn) r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKX2 ((++) @(Maybe Nat) shm shn) r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKX2 sh r)
-> (ADMap target, target (ADTensorKind (TKX2 sh r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKX2 sh r)
Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d
DeltaScatterX @shm @shn StaticShX shm
shm StaticShX shn
shn StaticShX shp
shp IShX ((++) @(Maybe Nat) shp shn)
sh Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d IxXOf target shm -> IxXOf target shp
f -> case Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
-> FullShapeTK (TKX2 ((++) @(Maybe Nat) shm shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d of
FTKX IShX sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
StaticShX shm
-> (KnownShX shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX shm
shm ((KnownShX shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShX shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
StaticShX shn
-> (KnownShX shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX shn
shn ((KnownShX shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShX shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
StaticShX shp
-> (KnownShX shp => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX shp
shp ((KnownShX shp => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShX shp => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
let (ADMap target
s2, target (ADTensorKind (TKX2 sh r))
t) = IMap target
-> ADMap target
-> Delta target (TKX2 sh r)
-> (ADMap target, target (ADTensorKind (TKX2 sh r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKX2 sh r)
Delta target (TKX2 ((++) @(Maybe Nat) shm shn) r)
d
in (ADMap target
s2, forall (target :: Target) (shm :: [Maybe Nat]) (shn :: [Maybe Nat])
(shp :: [Maybe Nat]) (x :: TK).
(BaseTensor target, KnownShX shm, KnownShX shn, KnownShX shp,
KnownSTK x) =>
IShX ((++) @(Maybe Nat) shp shn)
-> target (TKX2 ((++) @(Maybe Nat) shm shn) x)
-> (IxXOf target shm -> IxXOf target shp)
-> target (TKX2 ((++) @(Maybe Nat) shp shn) x)
txscatter @_ @shm @shn IShX ((++) @(Maybe Nat) shp shn)
sh target (ADTensorKind (TKX2 sh r))
target (TKX2 ((++) @(Maybe Nat) shm shn) r)
t IxXOf target shm -> IxXOf target shp
f)
DeltaGatherX @shm @shn StaticShX shm
shm StaticShX shn
shn StaticShX shp
shp IShX ((++) @(Maybe Nat) shm shn)
sh Delta target (TKX2 ((++) @(Maybe Nat) shp shn) r)
d IxXOf target shm -> IxXOf target shp
f -> case Delta target (TKX2 ((++) @(Maybe Nat) shp shn) r)
-> FullShapeTK (TKX2 ((++) @(Maybe Nat) shp shn) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((++) @(Maybe Nat) shp shn) r)
d of
FTKX IShX sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
StaticShX shm
-> (KnownShX shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX shm
shm ((KnownShX shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShX shm => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
StaticShX shn
-> (KnownShX shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX shn
shn ((KnownShX shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShX shn => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
StaticShX shp
-> (KnownShX shp => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (sh :: [Maybe Nat]) r.
StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX StaticShX shp
shp ((KnownShX shp => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownShX shp => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
let (ADMap target
s2, target (ADTensorKind (TKX2 sh r))
t) = IMap target
-> ADMap target
-> Delta target (TKX2 sh r)
-> (ADMap target, target (ADTensorKind (TKX2 sh r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKX2 sh r)
Delta target (TKX2 ((++) @(Maybe Nat) shp shn) r)
d
in (ADMap target
s2, forall (target :: Target) (shm :: [Maybe Nat]) (shn :: [Maybe Nat])
(shp :: [Maybe Nat]) (x :: TK).
(BaseTensor target, KnownShX shm, KnownShX shn, KnownShX shp,
KnownSTK x) =>
IShX ((++) @(Maybe Nat) shm shn)
-> target (TKX2 ((++) @(Maybe Nat) shp shn) x)
-> (IxXOf target shm -> IxXOf target shp)
-> target (TKX2 ((++) @(Maybe Nat) shm shn) x)
txgather @_ @shm @shn IShX ((++) @(Maybe Nat) shm shn)
sh target (ADTensorKind (TKX2 sh r))
target (TKX2 ((++) @(Maybe Nat) shp shn) r)
t IxXOf target shm -> IxXOf target shp
f)
DeltaAppendX Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
d Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
e -> case Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
-> FullShapeTK (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
d of
FTKX IShX sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
let (ADMap target
s2, target
(ADTensorKind (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r))
t) = IMap target
-> ADMap target
-> Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
-> (ADMap target,
target
(ADTensorKind (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
d
(ADMap target
s3, target
(ADTensorKind (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r))
u) = IMap target
-> ADMap target
-> Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
-> (ADMap target,
target
(ADTensorKind (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s2 Delta target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
e
in (ADMap target
s3, target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh) r)
forall (m :: Nat) (n :: Nat) (sh :: [Maybe Nat]) (x :: TK).
KnownSTK x =>
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh) x)
forall (target :: Target) (m :: Nat) (n :: Nat) (sh :: [Maybe Nat])
(x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat (m + n)) sh) x)
txappend target
(ADTensorKind (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r))
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat m) sh) r)
t target
(ADTensorKind (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r))
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
u)
DeltaSliceX SNat i
i SNat n
n SNat k
k Delta
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
d -> case Delta
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
-> FullShapeTK
(TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
d of
FTKX IShX sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
-> target (ADTensorKind y))
-> (ADMap target,
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second (SNat i
-> SNat n
-> SNat k
-> target
(TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) r)
forall (i :: Nat) (n :: Nat) (k :: Nat) (sh :: [Maybe Nat])
(x :: TK).
KnownSTK x =>
SNat i
-> SNat n
-> SNat k
-> target
(TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
forall (target :: Target) (i :: Nat) (n :: Nat) (k :: Nat)
(sh :: [Maybe Nat]) (x :: TK).
(BaseTensor target, KnownSTK x) =>
SNat i
-> SNat n
-> SNat k
-> target
(TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) x)
-> target (TKX2 ((':) @(Maybe Nat) ('Just @Nat n) sh) x)
txslice SNat i
i SNat n
n SNat k
k) ((ADMap target,
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target,
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
-> (ADMap target,
target
(ADTensorKind
(TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta
target (TKX2 ((':) @(Maybe Nat) ('Just @Nat ((i + n) + k)) sh) r)
d
DeltaReverseX Delta target (TKX2 ((':) @(Maybe Nat) mn sh) r)
d -> case Delta target (TKX2 ((':) @(Maybe Nat) mn sh) r)
-> FullShapeTK (TKX2 ((':) @(Maybe Nat) mn sh) r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 ((':) @(Maybe Nat) mn sh) r)
d of
FTKX IShX sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKX2 ((':) @(Maybe Nat) mn sh) r)
-> target (ADTensorKind y))
-> (ADMap target, target (TKX2 ((':) @(Maybe Nat) mn sh) r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second target (TKX2 ((':) @(Maybe Nat) mn sh) r)
-> target (ADTensorKind y)
target (TKX2 ((':) @(Maybe Nat) mn sh) r)
-> target (TKX2 ((':) @(Maybe Nat) mn sh) r)
forall (mn :: Maybe Nat) (sh :: [Maybe Nat]) (x :: TK).
KnownSTK x =>
target (TKX2 ((':) @(Maybe Nat) mn sh) x)
-> target (TKX2 ((':) @(Maybe Nat) mn sh) x)
forall (target :: Target) (mn :: Maybe Nat) (sh :: [Maybe Nat])
(x :: TK).
(BaseTensor target, KnownSTK x) =>
target (TKX2 ((':) @(Maybe Nat) mn sh) x)
-> target (TKX2 ((':) @(Maybe Nat) mn sh) x)
txreverse ((ADMap target, target (TKX2 ((':) @(Maybe Nat) mn sh) r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKX2 ((':) @(Maybe Nat) mn sh) r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKX2 ((':) @(Maybe Nat) mn sh) r)
-> (ADMap target,
target (ADTensorKind (TKX2 ((':) @(Maybe Nat) mn sh) r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKX2 ((':) @(Maybe Nat) mn sh) r)
d
DeltaTransposeX Perm perm
perm Delta target (TKX2 sh r)
d -> case Delta target (TKX2 sh r) -> FullShapeTK (TKX2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 sh r)
d of
FTKX IShX sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKX2 sh r) -> target (ADTensorKind y))
-> (ADMap target, target (TKX2 sh r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second (Perm perm
-> target (TKX2 sh r)
-> target (TKX2 (PermutePrefix @(Maybe Nat) perm sh) r)
forall (perm :: [Nat]) (sh :: [Maybe Nat]) (x :: TK).
(IsPermutation perm,
(<=) @Nat (Rank @Nat perm) (Rank @(Maybe Nat) sh), KnownSTK x) =>
Perm perm
-> target (TKX2 sh x)
-> target (TKX2 (PermutePrefix @(Maybe Nat) perm sh) x)
forall (target :: Target) (perm :: [Nat]) (sh :: [Maybe Nat])
(x :: TK).
(BaseTensor target, IsPermutation perm,
(<=) @Nat (Rank @Nat perm) (Rank @(Maybe Nat) sh), KnownSTK x) =>
Perm perm
-> target (TKX2 sh x)
-> target (TKX2 (PermutePrefix @(Maybe Nat) perm sh) x)
txtranspose Perm perm
perm) ((ADMap target, target (TKX2 sh r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKX2 sh r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKX2 sh r)
-> (ADMap target, target (ADTensorKind (TKX2 sh r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKX2 sh r)
d
DeltaReshapeX IShX sh2
sh2 Delta target (TKX2 sh r)
d -> case Delta target (TKX2 sh r) -> FullShapeTK (TKX2 sh r)
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target (TKX2 sh r)
d of
FTKX IShX sh
_ FullShapeTK x
x ->
SingletonTK x
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK (FullShapeTK x -> SingletonTK x
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK FullShapeTK x
x) ((KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (KnownSTK x => (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target (TKX2 sh r) -> target (ADTensorKind y))
-> (ADMap target, target (TKX2 sh r))
-> (ADMap target, target (ADTensorKind y))
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second (IShX sh2 -> target (TKX2 sh r) -> target (TKX2 sh2 r)
forall (sh :: [Maybe Nat]) (sh2 :: [Maybe Nat]) (x :: TK).
KnownSTK x =>
IShX sh2 -> target (TKX2 sh x) -> target (TKX2 sh2 x)
forall (target :: Target) (sh :: [Maybe Nat]) (sh2 :: [Maybe Nat])
(x :: TK).
(BaseTensor target, KnownSTK x) =>
IShX sh2 -> target (TKX2 sh x) -> target (TKX2 sh2 x)
txreshape IShX sh2
sh2) ((ADMap target, target (TKX2 sh r))
-> (ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (TKX2 sh r))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$ IMap target
-> ADMap target
-> Delta target (TKX2 sh r)
-> (ADMap target, target (ADTensorKind (TKX2 sh r)))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target (TKX2 sh r)
d
DeltaConvert @a TKConversion a1 y
c1 Delta target a1
d ->
(:~:) @TK (ADTensorKind a1) a1
-> (((ADTensorKind a1 :: TK) ~ (a1 :: TK)) =>
(ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall {k} (a :: k) (b :: k) r.
(:~:) @k a b -> (((a :: k) ~ (b :: k)) => r) -> r
gcastWith ((:~:) @TK (ADTensorKind a1) a1
forall {k} (a :: k) (b :: k). (:~:) @k a b
unsafeCoerceRefl :: ADTensorKind a :~: a) ((((ADTensorKind a1 :: TK) ~ (a1 :: TK)) =>
(ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y)))
-> (((ADTensorKind a1 :: TK) ~ (a1 :: TK)) =>
(ADMap target, target (ADTensorKind y)))
-> (ADMap target, target (ADTensorKind y))
forall a b. (a -> b) -> a -> b
$
(target a1 -> target y)
-> (ADMap target, target a1) -> (ADMap target, target y)
forall b c d. (b -> c) -> (d, b) -> (d, c)
forall (a :: Type -> Type -> Type) b c d.
Arrow a =>
a b c -> a (d, b) (d, c)
second (TKConversion a1 y -> SingletonTK a1 -> target a1 -> target y
forall (a :: TK) (b :: TK).
TKConversion a b -> SingletonTK a -> target a -> target b
forall (target :: Target) (a :: TK) (b :: TK).
ConvertTensor target =>
TKConversion a b -> SingletonTK a -> target a -> target b
tconvert TKConversion a1 y
c1 (FullShapeTK a1 -> SingletonTK a1
forall (y :: TK). FullShapeTK y -> SingletonTK y
ftkToSTK (Delta target a1 -> FullShapeTK a1
forall (target :: Target) (y :: TK).
Delta target y -> FullShapeTK y
ftkDelta Delta target a1
d)))
(IMap target
-> ADMap target
-> Delta target a1
-> (ADMap target, target (ADTensorKind a1))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target,
(y :: TK) ~ (ADTensorKind y :: TK)) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwdSame IMap target
params ADMap target
s Delta target a1
d)
Delta target y
d -> IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
forall (target :: Target) (y :: TK).
(ADReadyNoLet target, ShareTensor target) =>
IMap target
-> ADMap target
-> Delta target y
-> (ADMap target, target (ADTensorKind y))
evalFwd IMap target
params ADMap target
s Delta target y
d