{-# LANGUAGE UndecidableInstances #-}
module HordeAd.Core.CarriersADVal
(
ADVal, pattern D, dD, dDnotShared
, unDeltaPair, unDeltaPairUnshared, dScale, dAdd
, dSFromR, dSFromX, dXFromS
, ensureToplevelSharing, scaleNotShared, addNotShared, multNotShared
, generateDeltaInputs
) where
import Prelude
import Data.Int (Int64)
import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality ((:~:) (Refl))
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Lemmas
import HordeAd.Core.CarriersConcrete
import HordeAd.Core.Delta
import HordeAd.Core.DeltaFreshId
import HordeAd.Core.Ops
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
-- Dual-number carrier: a value of type @f y@ paired with a @Delta f y@
-- (the two constructor fields below). Both type parameters are given the
-- nominal role by the role annotation.
type role ADVal nominal nominal
data ADVal (f :: Target) y = ADVal (f y) (Delta f y)
-- | Match-only (unidirectional) pattern synonym exposing the two fields of
-- the 'ADVal' constructor. It cannot be used to build values; this file
-- builds them with 'dD' or 'dDnotShared'.
pattern D :: f z -> Delta f z -> ADVal f z
pattern D t u <- ADVal t u
{-# COMPLETE D #-}
-- | Standalone-derived 'Show'; requires 'Show' for both stored components.
deriving instance
  (Show (f z), Show (Delta f z)) => Show (ADVal f z)
-- | Constructor for 'ADVal' that passes the dual component through
-- 'shareDelta' before delegating to 'dDnotShared'.
--
-- Both arguments are forced to WHNF by the bang patterns.
-- NOTE(review): 'shareDelta' is imported and not shown here; presumably it
-- records sharing information for the delta expression -- confirm.
dD :: forall f z.
      f z -> Delta f z -> ADVal f z
dD !a !dual = dDnotShared a (shareDelta dual)
-- | Plain constructor for 'ADVal': stores the two components as given,
-- without calling 'shareDelta' on the dual part (contrast with 'dD').
dDnotShared :: f z -> Delta f z -> ADVal f z
dDnotShared = ADVal
-- | Split a delta of a product into deltas of its two components.
--
-- * An explicit 'DeltaPair' is taken apart directly.
-- * A 'DeltaZero' of a product shape becomes a pair of 'DeltaZero's built
--   from the two component shapes.
-- * Any other delta is first passed through 'shareDelta', and the two
--   projections are then taken from that single shared term.
unDeltaPair :: Delta target (TKProduct x y) -> (Delta target x, Delta target y)
unDeltaPair (DeltaPair a b) = (a, b)
unDeltaPair (DeltaZero (FTKProduct ftk1 ftk2)) =
  (DeltaZero ftk1, DeltaZero ftk2)
unDeltaPair d =
  -- Both projections refer to the same 'dShared' term.
  let dShared = shareDelta d
  in (DeltaProject1 dShared, DeltaProject2 dShared)
-- | Variant of 'unDeltaPair' that does not call 'shareDelta' in the general
-- case: the two projections are applied to the argument as is.
--
-- The 'DeltaPair' and 'DeltaZero' cases are handled exactly as in
-- 'unDeltaPair'.
unDeltaPairUnshared :: Delta target (TKProduct x y)
                    -> (Delta target x, Delta target y)
unDeltaPairUnshared (DeltaPair a b) = (a, b)
unDeltaPairUnshared (DeltaZero (FTKProduct ftk1 ftk2)) =
  (DeltaZero ftk1, DeltaZero ftk2)
unDeltaPairUnshared d = (DeltaProject1 d, DeltaProject2 d)
-- | Scale a delta by a primal value.
--
-- Scaling a 'DeltaZero' returns a 'DeltaZero' of the same shape (the scale
-- factor is ignored); otherwise the factor is wrapped in 'NestedTarget' and
-- a 'DeltaScale' node is built.
dScale :: Num (f z) => f z -> Delta f z -> Delta f z
dScale _ (DeltaZero ftk) = DeltaZero ftk
dScale v u' = DeltaScale (NestedTarget v) u'
-- | Add two deltas.
--
-- If either argument is a 'DeltaZero' the other argument is returned
-- unchanged (the left argument is inspected first); otherwise a 'DeltaAdd'
-- node is built.
dAdd :: Num (f z) => Delta f z -> Delta f z -> Delta f z
dAdd DeltaZero{} w = w
dAdd v DeltaZero{} = v
dAdd v w = DeltaAdd v w
-- | Convert a delta of a ranked tensor ('TKR2') into a delta of a shaped
-- tensor ('TKS2') with the given shape singleton.
--
-- First equation: if the argument is itself a 'DeltaConvert' and the full
-- shape of the delta under that conversion already equals @FTKS sh x@
-- (checked with 'matchingFTK'), the inner delta is returned and the
-- conversion is dropped.
--
-- Second equation: otherwise the argument is wrapped in a 'DeltaConvert'
-- with the composite conversion @ConvCmp (ConvXS' (FTKS sh x)) ConvRX@.
-- Matching 'Refl' from 'lemRankReplicate' brings the rank equality needed
-- by that composition into scope.
dSFromR :: forall sh x target.
           ShS sh -> Delta target (TKR2 (Rank sh) x)
        -> Delta target (TKS2 sh x)
dSFromR sh w@(DeltaConvert _c d)
  | FTKR _ x <- ftkDelta w
  , Just Refl <- matchingFTK (ftkDelta d) (FTKS sh x) = d
dSFromR sh d
  | FTKR _ x <- ftkDelta d
  , Refl <- lemRankReplicate (Proxy @(Rank sh)) =
      let c2 = ConvCmp (ConvXS' (FTKS sh x)) ConvRX
      in DeltaConvert c2 d
-- | Convert a delta of a mixed-shape tensor ('TKX2') into a delta of a
-- shaped tensor ('TKS2') of equal rank.
--
-- First equation: if the argument is a 'DeltaConvert' whose inner delta
-- already has full shape @FTKS sh x@ (checked with 'matchingFTK'), that
-- inner delta is returned directly.
--
-- Second equation: otherwise the argument is wrapped in a 'DeltaConvert'
-- using @ConvXS' (FTKS sh x)@.
dSFromX :: forall sh sh' x target. Rank sh ~ Rank sh'
        => ShS sh -> Delta target (TKX2 sh' x)
        -> Delta target (TKS2 sh x)
dSFromX sh w@(DeltaConvert _c d)
  | FTKX _ x <- ftkDelta w
  , Just Refl <- matchingFTK (ftkDelta d) (FTKS sh x) = d
dSFromX sh d
  | FTKX _ x <- ftkDelta d =
      let c2 = ConvXS' (FTKS sh x)
      in DeltaConvert c2 d
-- | Convert a delta of a shaped tensor ('TKS2') into a delta of a
-- mixed-shape tensor ('TKX2') of equal rank.
--
-- In both equations the target shape @shx@ is computed by 'shCastSX' from
-- the given 'StaticShX' and the shape read off the argument's full shape.
--
-- First equation: if the argument is a 'DeltaConvert' whose inner delta
-- already has full shape @FTKX shx x@ (checked with 'matchingFTK'), that
-- inner delta is returned directly.
--
-- Second equation: otherwise the argument is wrapped in a 'DeltaConvert'
-- with the composite conversion @ConvCmp (ConvXX' (FTKX shx x)) ConvSX@.
-- Matching 'Refl' from 'lemRankMapJust' brings the rank equality needed by
-- that composition into scope.
dXFromS :: forall sh sh' x target. Rank sh ~ Rank sh'
        => StaticShX sh' -> Delta target (TKS2 sh x)
        -> Delta target (TKX2 sh' x)
dXFromS ssx w@(DeltaConvert _c d)
  | FTKS sh x <- ftkDelta w
  , let shx = shCastSX ssx sh
  , Just Refl <- matchingFTK (ftkDelta d) (FTKX shx x) = d
dXFromS ssx d
  | FTKS sh x <- ftkDelta d
  , let shx = shCastSX ssx sh
  , Refl <- lemRankMapJust sh =
      let c2 = ConvCmp (ConvXX' (FTKX shx x)) ConvSX
      in DeltaConvert c2 d
-- | Build a primal value of the same full shape as the given delta from an
-- 'Int' constant: the shape is obtained with 'ftkDelta', the constant is
-- converted with 'fromIntegral', and both are handed to 'treplTarget'.
--
-- Not exported from this module.
intOfShape :: forall z f. ADReadyNoLet f
           => Delta f z -> Int -> f z
intOfShape tsh c = treplTarget (fromIntegral c) (ftkDelta tsh)
-- | Rebuild a dual number with 'dD', so that its dual component is passed
-- through 'shareDelta' (see 'dD'). The primal component is kept as is.
ensureToplevelSharing :: ADVal f z -> ADVal f z
ensureToplevelSharing (D u u') = dD u u'
-- | Multiply a dual number by a primal factor: the primal part becomes
-- @a * u@ and the dual part is scaled with 'dScale'.
--
-- The result is built with 'dDnotShared', i.e. without 'shareDelta'.
-- The factor is forced to WHNF by the bang pattern.
scaleNotShared :: Num (f z)
               => f z -> ADVal f z -> ADVal f z
scaleNotShared !a (D u u') = dDnotShared (a * u) (dScale a u')
-- | Add two dual numbers: primal parts with '+', dual parts with 'dAdd'.
--
-- The result is built with 'dDnotShared', i.e. without 'shareDelta'.
addNotShared :: forall f z. Num (f z)
             => ADVal f z -> ADVal f z -> ADVal f z
addNotShared (D u u') (D v v') = dDnotShared (u + v) (dAdd u' v')
-- | Multiply two dual numbers using the product rule: the primal part is
-- @u * v@ and the dual part is @v * u' + u * v'@, expressed with 'dScale'
-- and 'dAdd'.
--
-- The result is built with 'dDnotShared', and the primal components are
-- used directly (no 'tshare'), unlike the '*' of the 'Num' instances.
multNotShared :: forall f z. Num (f z)
              => ADVal f z -> ADVal f z -> ADVal f z
multNotShared (D u u') (D v v') =
  dDnotShared (u * v) (dAdd (dScale v u') (dScale u v'))
-- | Generate a delta expression made of input nodes for a given full shape.
--
-- The shape is traversed left to right with a running 'Int' counter that
-- starts at 0:
--
-- * a product shape yields a 'DeltaPair' of the results for its components,
--   threading the counter from the first component into the second;
-- * a non-product shape for which 'differentiableFTK' holds yields a
--   'DeltaInput' whose id is made by 'mkInputId' from the shape and the
--   current counter, and the counter is incremented;
-- * any other shape yields a 'DeltaZero' and leaves the counter unchanged.
--
-- Only the delta is returned; the final counter is discarded.
generateDeltaInputs :: forall x target.
                       FullShapeTK x -> Delta target x
generateDeltaInputs =
  let gen :: Int -> FullShapeTK y -> (Delta target y, Int)
      gen j ftk = case ftk of
        FTKProduct ftk1 ftk2 ->
          let (d1, j1) = gen j ftk1
              (d2, j2) = gen j1 ftk2
          in (DeltaPair d1 d2, j2)
        _ | differentiableFTK ftk -> (DeltaInput (mkInputId ftk j), j + 1)
        _ -> (DeltaZero ftk, j)
  in fst . gen 0
-- The boolean type of @ADVal f@ is the boolean type of the underlying @f@.
type instance BoolOf (ADVal f) = BoolOf f
-- | Equality on scalar dual numbers compares the primal components only;
-- the dual components are ignored.
instance EqH f (TKScalar r) => EqH (ADVal f) (TKScalar r) where
  D u _ ==. D v _ = u ==. v
-- | Ordering on scalar dual numbers compares the primal components only;
-- the dual components are ignored.
instance OrdH f (TKScalar r) => OrdH (ADVal f) (TKScalar r) where
  D u _ <=. D v _ = u <=. v
-- | Equality on ranked-tensor dual numbers compares the primal components
-- only; the dual components are ignored.
instance EqH f (TKR n r) => EqH (ADVal f) (TKR n r) where
  D u _ ==. D v _ = u ==. v
-- | Ordering on ranked-tensor dual numbers compares the primal components
-- only; the dual components are ignored.
instance OrdH f (TKR n r) => OrdH (ADVal f) (TKR n r) where
  D u _ <=. D v _ = u <=. v
-- | Equality on shaped-tensor dual numbers compares the primal components
-- only; the dual components are ignored.
instance EqH f (TKS sh r) => EqH (ADVal f) (TKS sh r) where
  D u _ ==. D v _ = u ==. v
-- | Ordering on shaped-tensor dual numbers compares the primal components
-- only; the dual components are ignored.
instance OrdH f (TKS sh r) => OrdH (ADVal f) (TKS sh r) where
  D u _ <=. D v _ = u <=. v
-- | Equality on mixed-shape-tensor dual numbers compares the primal
-- components only; the dual components are ignored.
instance EqH f (TKX sh r) => EqH (ADVal f) (TKX sh r) where
  D u _ ==. D v _ = u ==. v
-- | Ordering on mixed-shape-tensor dual numbers compares the primal
-- components only; the dual components are ignored.
instance OrdH f (TKX sh r) => OrdH (ADVal f) (TKX sh r) where
  D u _ <=. D v _ = u <=. v
-- Associated type-family instances for the @ADVal f@ carrier.
-- Functions are represented by 'HFun', independently of @f@.
type instance HFunOf (ADVal f) x y = HFun x y
-- The primal part of @ADVal f@ is the underlying carrier @f@.
type instance PrimalOf (ADVal f) = f
-- The dual part of @ADVal f@ is a delta expression over @f@.
type instance DualOf (ADVal f) = Delta f
-- The sharing variant of @ADVal f@ is @ADVal f@ itself.
type instance ShareOf (ADVal f) = ADVal f
-- | Deliberately unusable 'Eq' instance: both methods call 'error' with a
-- message pointing at 'EqH'.
-- NOTE(review): presumably this instance exists only to satisfy superclass
-- constraints (e.g. of 'Ord') -- confirm.
instance Eq (ADVal f z) where
  (==) = error "Eq is not defined for ADVal; please use EqH instead"
  (/=) = error "Eq is not defined for ADVal; please use EqH instead"
-- | Deliberately unusable 'Ord' instance: '<=' calls 'error' with a message
-- pointing at 'OrdH'. The remaining 'Ord' methods are left to the class
-- defaults.
instance Ord (ADVal f z) where
  (<=) = error "Ord is not defined for ADVal; please use OrdH instead"
-- | Arithmetic on scalar dual numbers.
--
-- Every method computes the primal part with the corresponding 'Num'
-- method of @f (TKScalar r)@ and the dual part by the matching
-- differentiation rule, expressed with 'dAdd' and 'dScale'.
-- Results of '+', '-', '*', 'negate' and 'abs' are built with 'dD';
-- 'signum' and 'fromInteger' have a 'DeltaZero' dual part and are built
-- with 'dDnotShared'.
instance (GoodScalar r, ShareTensor f, ADReadyNoLet f)
         => Num (ADVal f (TKScalar r)) where
  D u u' + D v v' = dD (u + v) (dAdd u' v')
  -- Subtraction adds the right-hand delta scaled by -1; the constant is
  -- built by 'intOfShape' from the shape of that delta.
  D u u' - D v v' =
    dD (u - v) (dAdd u' (dScale (intOfShape v' (-1)) v'))
  -- Product rule. Each primal is used twice (in the product and as a scale
  -- factor), so both go through 'tshare' and are forced first.
  D ue u' * D ve v' =
    let !u = tshare ue in
    let !v = tshare ve
    in dD (u * v) (dAdd (dScale v u') (dScale u v'))
  negate (D v v') = dD (negate v) (dScale (intOfShape v' (-1)) v')
  -- The primal is used twice ('abs' and 'signum'), hence 'tshare'.
  abs (D ve v') = let !v = tshare ve
                  in dD (abs v) (dScale (signum v) v')
  signum (D v v') = dDnotShared (signum v) (DeltaZero $ ftkDelta v')
  fromInteger i = dDnotShared (fromInteger i) (DeltaZero FTKScalar)
  {-# SPECIALIZE instance (ShareTensor Concrete, ADReadyNoLet Concrete) => Num (ADVal Concrete (TKScalar Double)) #-}
  {-# SPECIALIZE instance (ShareTensor Concrete, ADReadyNoLet Concrete) => Num (ADVal Concrete (TKScalar Float)) #-}
  {-# SPECIALIZE instance (ShareTensor Concrete, ADReadyNoLet Concrete) => Num (ADVal Concrete (TKScalar Int64)) #-}
-- | Generic 'Num' instance for dual numbers: every method computes the primal
-- result with the underlying @Num (f z)@ instance and pairs it with the delta
-- expression of the corresponding derivative rule.
--
-- The instance is @OVERLAPPABLE@ so that a more specific instance (e.g. one
-- for @TKScalar@) can take precedence where it exists.
--
-- Conventions used below:
--
-- * 'dD' builds the result and passes the dual through 'shareDelta';
--   'dDnotShared' stores the dual as given.
-- * Primal values that occur more than once on the right-hand side are first
--   passed through 'tshare' and bound strictly (presumably so that the primal
--   computation is not duplicated -- confirm against 'ShareTensor').
-- * @intOfShape d k@ supplies the integer constant @k@ as a primal value
--   matching the delta @d@ (presumably a tensor of the shape of @d@ filled
--   with @k@ -- confirm against its definition).
instance {-# OVERLAPPABLE #-}
         (Num (f z), ShareTensor f, ADReadyNoLet f)
         => Num (ADVal f z) where
  -- d(u + v) = du + dv
  D u u' + D v v' = dD (u + v) (dAdd u' v')
  -- d(u - v) = du + (-1) * dv
  D u u' - D v v' =
    dD (u - v) (dAdd u' (dScale (intOfShape v' (-1)) v'))
  -- Product rule: d(u * v) = v * du + u * dv
  D ue u' * D ve v' =
    let !u = tshare ue in
    let !v = tshare ve
    in dD (u * v) (dAdd (dScale v u') (dScale u v'))
  -- d(negate v) = (-1) * dv
  negate (D v v') = dD (negate v) (dScale (intOfShape v' (-1)) v')
  -- d(abs v) = signum v * dv
  abs (D ve v') = let !v = tshare ve
                  in dD (abs v) (dScale (signum v) v')
  -- The dual of signum is a zero delta of the same shape as the argument's.
  signum (D v v') = dDnotShared (signum v) (DeltaZero $ ftkDelta v')
  -- No shape is available here, so a constant cannot be built in general.
  fromInteger = error "fromInteger is not defined for tensors in general"
  {-# SPECIALIZE instance (ShareTensor Concrete, ADReadyNoLet Concrete) => Num (ADVal Concrete (TKR n Double)) #-}
  {-# SPECIALIZE instance (ShareTensor Concrete, ADReadyNoLet Concrete) => Num (ADVal Concrete (TKR n Float)) #-}
  {-# SPECIALIZE instance (ShareTensor Concrete, ADReadyNoLet Concrete) => Num (ADVal Concrete (TKR n Int64)) #-}
  {-# SPECIALIZE instance (ShareTensor Concrete, ADReadyNoLet Concrete) => Num (ADVal Concrete (TKS sh Double)) #-}
  {-# SPECIALIZE instance (ShareTensor Concrete, ADReadyNoLet Concrete) => Num (ADVal Concrete (TKS sh Float)) #-}
  {-# SPECIALIZE instance (ShareTensor Concrete, ADReadyNoLet Concrete) => Num (ADVal Concrete (TKS sh Int64)) #-}
  {-# SPECIALIZE instance (ShareTensor Concrete, ADReadyNoLet Concrete) => Num (ADVal Concrete (TKX sh Double)) #-}
  {-# SPECIALIZE instance (ShareTensor Concrete, ADReadyNoLet Concrete) => Num (ADVal Concrete (TKX sh Float)) #-}
  {-# SPECIALIZE instance (ShareTensor Concrete, ADReadyNoLet Concrete) => Num (ADVal Concrete (TKX sh Int64)) #-}
-- | 'Real' instance for dual numbers: the conversion looks only at the primal
-- component and ignores the dual.
instance (Real (f z), ShareTensor f, ADReadyNoLet f)
         => Real (ADVal f z) where
  toRational (D v _) = toRational v
-- | 'IntegralH' instance for dual numbers. Both operations compute the primal
-- result with the underlying instance and attach a zero delta; the dual of the
-- first argument is ignored and the dual of the second one is consulted only
-- (via 'ftkDelta') for the shape of the zero delta.
instance (IntegralH (f z), ShareTensor f, ADReadyNoLet f)
         => IntegralH (ADVal f z) where
  quotH (D u _) (D v v') = dDnotShared (quotH u v) (DeltaZero $ ftkDelta v')
  remH (D u _) (D v v') = dDnotShared (remH u v) (DeltaZero $ ftkDelta v')
  {-# SPECIALIZE instance (ShareTensor Concrete, ADReadyNoLet Concrete) => IntegralH (ADVal Concrete (TKR n Int64)) #-}
  {-# SPECIALIZE instance (ShareTensor Concrete, ADReadyNoLet Concrete) => IntegralH (ADVal Concrete (TKS sh Int64)) #-}
  {-# SPECIALIZE instance (ShareTensor Concrete, ADReadyNoLet Concrete) => IntegralH (ADVal Concrete (TKX sh Int64)) #-}
-- | 'Fractional' instance for dual numbers over scalars. Unlike the generic
-- overlappable instance, 'fromRational' is defined here, because the shape of
-- a scalar ('FTKScalar') is known statically and a zero delta can be built.
--
-- Primal values used more than once are passed through 'tshare' and bound
-- strictly before use (presumably to avoid duplicating their computation --
-- confirm against 'ShareTensor').
instance ( GoodScalar r, Fractional (f (TKScalar r)), ShareTensor f
         , ADReadyNoLet f )
         => Fractional (ADVal f (TKScalar r)) where
  -- Quotient rule: d(u / v) = (1 / v) * du + (- u / (v * v)) * dv
  D ue u' / D ve v' =
    let !u = tshare ue in
    let !v = tshare ve
    in dD (u / v) (dAdd (dScale (recip v) u') (dScale ((- u) / (v * v)) v'))
  -- d(1 / v) = - (1 / (v * v)) * dv
  recip (D ve v') =
    let !v = tshare ve
        minusRecipSq = - recip (v * v)
    in dD (recip v) (dScale minusRecipSq v')
  -- A constant: its dual is the zero delta of scalar shape.
  fromRational r = dDnotShared (fromRational r) (DeltaZero FTKScalar)
-- | Generic 'Fractional' instance for dual numbers over an arbitrary tensor
-- kind. It is @OVERLAPPABLE@, so the more specific instance for @TKScalar@
-- takes precedence for scalars.
--
-- Primal values used more than once are passed through 'tshare' and bound
-- strictly before use (presumably to avoid duplicating their computation --
-- confirm against 'ShareTensor').
instance {-# OVERLAPPABLE #-}
         (Fractional (f z), ShareTensor f, ADReadyNoLet f)
         => Fractional (ADVal f z) where
  -- Quotient rule: d(u / v) = (1 / v) * du + (- u / (v * v)) * dv
  D ue u' / D ve v' =
    let !u = tshare ue in
    let !v = tshare ve
    in dD (u / v) (dAdd (dScale (recip v) u') (dScale ((- u) / (v * v)) v'))
  -- d(1 / v) = - (1 / (v * v)) * dv
  recip (D ve v') =
    let !v = tshare ve
        minusRecipSq = - recip (v * v)
    in dD (recip v) (dScale minusRecipSq v')
  -- No shape is available here, so a constant cannot be built in general.
  fromRational = error "fromRational is not defined for tensors in general"
-- | 'Floating' instance for dual numbers: each method computes the primal
-- result with the underlying @Floating (f z)@ instance and scales the incoming
-- delta by the derivative of the function at the primal argument.
--
-- Conventions used below:
--
-- * Primal values (or primal results) that occur more than once are passed
--   through 'tshare' and bound strictly before use (presumably to avoid
--   duplicating their computation -- confirm against 'ShareTensor').
-- * @intOfShape d k@ supplies the integer constant @k@ as a primal value
--   matching the delta @d@ (presumably a tensor of the shape of @d@ filled
--   with @k@ -- confirm against its definition).
instance (Floating (f z), ShareTensor f, ADReadyNoLet f)
         => Floating (ADVal f z) where
  pi = error "pi is not defined for tensors"
  -- d(exp u) = exp u * du; the primal result is reused as the scale factor.
  exp (D ue u') = let !expU = tshare (exp ue)
                  in dD expU (dScale expU u')
  -- d(log u) = (1 / u) * du
  log (D ue u') = let !u = tshare ue
                  in dD (log u) (dScale (recip u) u')
  -- d(sqrt u) = (1 / (2 * sqrt u)) * du, with 2 * sqrt u written as a sum.
  sqrt (D ue u') = let !sqrtU = tshare (sqrt ue)
                   in dD sqrtU (dScale (recip (sqrtU + sqrtU)) u')
  -- d(u ** v) = v * u ** (v - 1) * du + (u ** v) * log u * dv
  D ue u' ** D ve v' =
    let !u = tshare ue in
    let !v = tshare ve
    in dD (u ** v) (dAdd (dScale (v * (u ** (v - intOfShape v' 1))) u')
                         (dScale ((u ** v) * log u) v'))
  -- d(sin u) = cos u * du
  sin (D ue u') = let !u = tshare ue
                  in dD (sin u) (dScale (cos u) u')
  -- d(cos u) = - sin u * du
  cos (D ue u') = let !u = tshare ue
                  in dD (cos u) (dScale (- (sin u)) u')
  -- d(tan u) = (1 / (cos u * cos u)) * du
  tan (D ue u') = let !u = tshare ue in
                  let !cosU = tshare (cos u)
                  in dD (tan u) (dScale (recip (cosU * cosU)) u')
  -- d(asin u) = (1 / sqrt (1 - u * u)) * du
  asin (D ue u') = let !u = tshare ue
                   in dD (asin u)
                         (dScale (recip (sqrt (intOfShape u' 1 - u * u))) u')
  -- d(acos u) = - (1 / sqrt (1 - u * u)) * du
  acos (D ue u') = let !u = tshare ue
                   in dD (acos u)
                         (dScale (- recip (sqrt (intOfShape u' 1 - u * u))) u')
  -- d(atan u) = (1 / (1 + u * u)) * du
  atan (D ue u') = let !u = tshare ue
                   in dD (atan u)
                         (dScale (recip (intOfShape u' 1 + u * u)) u')
  -- d(sinh u) = cosh u * du
  sinh (D ue u') = let !u = tshare ue
                   in dD (sinh u) (dScale (cosh u) u')
  -- d(cosh u) = sinh u * du
  cosh (D ue u') = let !u = tshare ue
                   in dD (cosh u) (dScale (sinh u) u')
  -- d(tanh u) = (1 - tanh u * tanh u) * du; the primal result is reused.
  tanh (D ue u') = let !y = tshare (tanh ue)
                   in dD y (dScale (intOfShape u' 1 - y * y) u')
  -- d(asinh u) = (1 / sqrt (1 + u * u)) * du
  asinh (D ue u') = let !u = tshare ue
                    in dD (asinh u)
                          (dScale (recip (sqrt (intOfShape u' 1 + u * u))) u')
  -- d(acosh u) = (1 / sqrt (u * u - 1)) * du
  acosh (D ue u') = let !u = tshare ue
                    in dD (acosh u)
                          (dScale (recip (sqrt (u * u - intOfShape u' 1))) u')
  -- d(atanh u) = (1 / (1 - u * u)) * du
  atanh (D ue u') = let !u = tshare ue
                    in dD (atanh u)
                          (dScale (recip (intOfShape u' 1 - u * u)) u')
-- | 'RealFrac' instance for dual numbers. It exists only to satisfy the class
-- hierarchy: the single method defined here always fails with an error.
instance (RealFrac (f z), ShareTensor f, ADReadyNoLet f)
         => RealFrac (ADVal f z) where
  properFraction = error "properFraction is not defined for tensors"
-- | 'RealFloatH' instance for dual numbers.
--
-- Both primal arguments and the common factor @1 / (u * u + v * v)@ are
-- passed through 'tshare' and bound strictly, since each is used more than
-- once (presumably to avoid duplicating their computation -- confirm against
-- 'ShareTensor').
instance (Fractional (f z), RealFloatH (f z), ShareTensor f, ADReadyNoLet f)
         => RealFloatH (ADVal f z) where
  -- d(atan2H u v) = (- u / (u * u + v * v)) * dv + (v / (u * u + v * v)) * du
  atan2H (D ue u') (D ve v') =
    let !u = tshare ue in
    let !v = tshare ve in
    let !t = tshare (recip (u * u + v * v))
    in dD (atan2H u v) (dAdd (dScale ((- u) * t) v') (dScale (v * t) u'))
-- | 'RealFloat' instance for dual numbers.
--
-- Only 'atan2' propagates a derivative. The inspection methods
-- ('floatRadix', 'isNaN', ...) delegate to the primal component and ignore
-- the dual, and 'encodeFloat' always fails with an error.
instance (RealFloat (f z), ShareTensor f, ADReadyNoLet f)
         => RealFloat (ADVal f z) where
  -- d(atan2 u v) = (- u / (u * u + v * v)) * dv + (v / (u * u + v * v)) * du
  -- Both primals and the common factor are shared, since each is used more
  -- than once.
  atan2 (D ue u') (D ve v') =
    let !u = tshare ue in
    let !v = tshare ve in
    let !t = tshare (recip (u * u + v * v))
    in dD (atan2 u v) (dAdd (dScale ((- u) * t) v') (dScale (v * t) u'))
  floatRadix (D u _) = floatRadix u
  floatDigits (D u _) = floatDigits u
  floatRange (D u _) = floatRange u
  decodeFloat (D u _) = decodeFloat u
  encodeFloat _i _j = error "encodeFloat is not defined for tensors"
  isNaN (D u _) = isNaN u
  isInfinite (D u _) = isInfinite u
  isDenormalized (D u _) = isDenormalized u
  isNegativeZero (D u _) = isNegativeZero u
  isIEEE (D u _) = isIEEE u