{-# OPTIONS_GHC -Wno-orphans #-}
module HordeAd.ADEngine
(
grad, vjp
, gradArtifact, vjpArtifact
, gradInterpretArtifact, vjpInterpretArtifact
, jvp, jvpArtifact, jvpInterpretArtifact
, cgrad, cvjp
, cjvp
, IncomingCotangentHandling(..)
, revArtifactAdapt, revArtifactDelta
, revProduceArtifactWithoutInterpretation, revInterpretArtifact
, fwdArtifactAdapt, fwdArtifactDelta, fwdInterpretArtifact
, cfwdBoth
) where
import Prelude
import HordeAd.AstEngine
import HordeAd.Core.Adaptor
import HordeAd.Core.Ast
import HordeAd.Core.AstEnv
import HordeAd.Core.AstInterpret
import HordeAd.Core.CarriersADVal
import HordeAd.Core.CarriersAst
import HordeAd.Core.CarriersConcrete
import HordeAd.Core.Delta
import HordeAd.Core.DeltaEval
import HordeAd.Core.Ops
import HordeAd.Core.OpsADVal
import HordeAd.Core.OpsAst
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
import HordeAd.Core.Unwind
-- | Reverse-mode gradient of a scalar-valued objective.
--
-- @grad f vals@ differentiates @f@ at the concrete point @vals@ and returns
-- the gradient in the same shape as @vals@. No incoming cotangent is passed
-- to 'revMaybe', so 'revInterpretArtifact' seeds the derivative with a
-- cotangent filled with ones.
grad
  :: forall src r tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan (TKScalar r) )
  => (src -> tgt)  -- ^ objective function with a scalar result
  -> Value src     -- ^ point at which to differentiate
  -> Value src
{-# INLINE grad #-}
grad f vals = revMaybe f vals Nothing
-- | Vector-Jacobian product of an objective with an arbitrary codomain.
--
-- @vjp f vals dt@ differentiates @f@ at @vals@ and contracts the result with
-- the explicitly supplied incoming cotangent @dt@. The shape of @dt@ is
-- checked against the codomain inside 'revInterpretArtifact'.
vjp
  :: forall src ztgt tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)                 -- ^ objective function
  -> Value src                    -- ^ point at which to differentiate
  -> Concrete (ADTensorKind ztgt) -- ^ incoming cotangent
  -> Value src
{-# INLINE vjp #-}
vjp f vals dt = revMaybe f vals $ Just dt
-- | Build (without interpreting) the reverse-mode artifact for a scalar
-- objective.
--
-- The concrete point @vals0@ is used only to determine the full shape of the
-- objective's domain. The artifact is produced with
-- 'IgnoreIncomingCotangent'.
gradArtifact
  :: forall src r tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan (TKScalar r) )
  => (src -> tgt)
  -> Value src
  -> AstArtifactRev (X src) (TKScalar r)
{-# INLINE gradArtifact #-}
gradArtifact f vals0 = revArtifactAdapt IgnoreIncomingCotangent f xftk
  where
    -- Full shape of the domain, read off the concrete sample point.
    xftk :: FullShapeTK (X src)
    xftk = tftkG (knownSTK @(X src)) (unConcrete (toTarget vals0))
-- | Build (without interpreting) the reverse-mode artifact for an objective
-- with an arbitrary codomain.
--
-- The concrete point @vals0@ is used only to determine the full shape of the
-- objective's domain. The artifact is produced with 'UseIncomingCotangent',
-- so interpreting it requires an explicit cotangent.
vjpArtifact
  :: forall src ztgt tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)
  -> Value src
  -> AstArtifactRev (X src) ztgt
{-# INLINE vjpArtifact #-}
vjpArtifact f vals0 = revArtifactAdapt UseIncomingCotangent f xftk
  where
    -- Full shape of the domain, read off the concrete sample point.
    xftk :: FullShapeTK (X src)
    xftk = tftkG (knownSTK @(X src)) (unConcrete (toTarget vals0))
-- | Interpret the gradient stored in a reverse-mode artifact for a scalar
-- objective at the concrete @parameters@.
--
-- The incoming cotangent variable is bound to @treplTarget 1@ of the
-- cotangent's shape. Calls 'error' when the shape of @parameters@ differs
-- from the shape of the artifact's domain variable.
gradInterpretArtifact
  :: forall x r avals.
     (X avals ~ ADTensorKind x, AdaptableTarget Concrete avals)
  => AstArtifactRev x (TKScalar r)
  -> Concrete x
  -> avals
{-# INLINE gradInterpretArtifact #-}
gradInterpretArtifact AstArtifactRev{..} parameters
  | parametersFit = fromTarget $ interpretAstPrimal env artDerivativeRev
  | otherwise =
      error "gradInterpretArtifact: reverse derivative parameters must have the same shape as the domain of the objective function"
  where
    xftk :: FullShapeTK x
    xftk = varNameToFTK artVarDomainRev
    parametersFit :: Bool
    parametersFit = tftkG (ftkToSTK xftk) (unConcrete parameters) == xftk
    -- The codomain is a scalar, so the seed cotangent is a single 1.
    oneAtF :: Concrete (ADTensorKind (TKScalar r))
    oneAtF = treplTarget 1 (varNameToFTK artVarDtRev)
    env :: AstEnv Concrete
    env = extendEnv artVarDtRev oneAtF
          $ extendEnv artVarDomainRev parameters emptyEnv
-- | Interpret the gradient stored in a reverse-mode artifact at the concrete
-- @parameters@, using the explicit incoming cotangent @dt@.
--
-- Calls 'error' when the shape of @parameters@ differs from the artifact's
-- domain variable. Otherwise it calls 'error' when the shape of @dt@ differs
-- from the artifact's cotangent variable. The parameter check runs first.
vjpInterpretArtifact
  :: forall x z avals.
     (X avals ~ ADTensorKind x, AdaptableTarget Concrete avals)
  => AstArtifactRev x z
  -> Concrete x
  -> Concrete (ADTensorKind z)
  -> avals
{-# INLINE vjpInterpretArtifact #-}
vjpInterpretArtifact AstArtifactRev{..} parameters dt
  | not parametersFit =
      error "vjpInterpretArtifact: reverse derivative parameters must have the same shape as the domain of the objective function"
  | not cotangentFits =
      error "vjpInterpretArtifact: reverse derivative incoming cotangent must have the same shape as the codomain of the objective function"
  | otherwise = fromTarget $ interpretAstPrimal env artDerivativeRev
  where
    xftk :: FullShapeTK x
    xftk = varNameToFTK artVarDomainRev
    azftk :: FullShapeTK (ADTensorKind z)
    azftk = varNameToFTK artVarDtRev
    parametersFit :: Bool
    parametersFit = tftkG (ftkToSTK xftk) (unConcrete parameters) == xftk
    cotangentFits :: Bool
    cotangentFits = tftkG (ftkToSTK azftk) (unConcrete dt) == azftk
    env :: AstEnv Concrete
    env = extendEnv artVarDtRev dt
          $ extendEnv artVarDomainRev parameters emptyEnv
-- | Shared worker behind 'grad' and 'vjp'.
--
-- Builds a reverse-mode artifact for @f@ and simplifies it with
-- 'simplifyArtifactGradient'. The artifact uses the incoming cotangent only
-- when @mdt@ is 'Just'. It is then interpreted at @vals0@, and the gradient is
-- converted back from its 'ADTensorKind' form into the shape of @vals0@. The
-- primal value computed by the interpretation is discarded.
revMaybe
  :: forall src ztgt tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)
  -> Value src
  -> Maybe (Concrete (ADTensorKind ztgt))
  -> Value src
{-# INLINE revMaybe #-}
revMaybe f vals0 mdt =
  fromTarget
  . fromADTensorKindShared (ftkToSTK xftk)
  . fst
  $ revInterpretArtifact artifact valsTarget mdt
  where
    valsTarget :: Concrete (X (Value src))
    valsTarget = toTarget vals0
    xftk :: FullShapeTK (X src)
    xftk = tftkG (knownSTK @(X src)) (unConcrete valsTarget)
    -- The artifact expects an explicit cotangent exactly when the caller
    -- supplied one.
    cotangentHandling :: IncomingCotangentHandling
    cotangentHandling = case mdt of
      Nothing -> IgnoreIncomingCotangent
      Just _ -> UseIncomingCotangent
    artifact :: AstArtifactRev (X src) ztgt
    artifact =
      simplifyArtifactGradient (revArtifactAdapt cotangentHandling f xftk)
-- | Produce a reverse-mode artifact for an objective over adaptable inputs.
--
-- The objective is wrapped so that its AST argument is forced to WHNF,
-- let-bound with 'ttlet' and converted with 'fromTarget'. The resulting term
-- is passed through 'simplifyInline'. The wrapper is then handed to
-- 'revProduceArtifact' together with an empty environment and the domain
-- shape @xftk@.
revArtifactAdapt
  :: forall src ztgt tgt.
     ( AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => IncomingCotangentHandling
  -> (src -> tgt)
  -> FullShapeTK (X src)
  -> AstArtifactRev (X src) ztgt
{-# INLINE revArtifactAdapt #-}
revArtifactAdapt cotangentHandling f xftk =
  revProduceArtifact cotangentHandling wrapped emptyEnv xftk
  where
    wrapped :: AstTensor AstMethodLet FullSpan (X src) -> tgt
    wrapped !arg = simplifyInline $ ttlet arg (f . fromTarget)
-- | Interpret a reverse-mode artifact at the concrete @parameters@.
--
-- Returns the pair (gradient, primal value). The primal is evaluated in an
-- environment binding only the domain variable. The gradient is evaluated in
-- an environment that also binds the cotangent variable, as follows:
--
-- * @Nothing@: the cotangent is @treplTarget 1@ of the cotangent's shape.
-- * @Just dt@: the cotangent is @dt@, after checking that its shape matches
--   the cotangent variable. A mismatch makes the gradient environment an
--   'error'.
revInterpretArtifact
  :: forall x z.
     AstArtifactRev x z
  -> Concrete x
  -> Maybe (Concrete (ADTensorKind z))
  -> (Concrete (ADTensorKind x), Concrete z)
{-# INLINE revInterpretArtifact #-}
revInterpretArtifact AstArtifactRev{..} parameters mdt = (gradient, primal)
  where
    azftk :: FullShapeTK (ADTensorKind z)
    azftk = varNameToFTK artVarDtRev
    env :: AstEnv Concrete
    env = extendEnv artVarDomainRev parameters emptyEnv
    bindDt :: Concrete (ADTensorKind z) -> AstEnv Concrete
    bindDt dtValue = extendEnv artVarDtRev dtValue env
    -- The shape error is raised at the level of the whole environment,
    -- not of the bound value.
    envDt :: AstEnv Concrete
    envDt = case mdt of
      Nothing -> bindDt (treplTarget 1 azftk)
      Just dt
        | tftkG (ftkToSTK azftk) (unConcrete dt) == azftk -> bindDt dt
        | otherwise ->
            error "revInterpretArtifact: reverse derivative incoming cotangent must have the same shape as the codomain of the objective function"
    gradient :: PrimalOf Concrete (ADTensorKind x)
    gradient = interpretAstPrimal envDt artDerivativeRev
    primal :: PrimalOf Concrete z
    primal = interpretAstPrimal env artPrimalRev
-- | Like 'revArtifactAdapt', but also returns the delta expression.
--
-- The objective is wrapped so that its AST argument is forced to WHNF,
-- let-bound with 'ttlet' and converted with 'fromTarget'. Unlike in
-- 'revArtifactAdapt', no 'simplifyInline' is applied. The forward pass
-- interprets the wrapper in an empty environment.
revArtifactDelta
  :: forall src ztgt tgt.
     ( AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => IncomingCotangentHandling
  -> (src -> tgt)
  -> FullShapeTK (X src)
  -> (AstArtifactRev (X src) ztgt, Delta (AstRaw PrimalSpan) ztgt)
{-# INLINE revArtifactDelta #-}
revArtifactDelta cotangentHandling f xftk =
  revArtifactFromForwardPass cotangentHandling forwardPass xftk
  where
    wrapped :: AstTensor AstMethodLet FullSpan (X src) -> tgt
    wrapped !arg = ttlet arg (f . fromTarget)
    forwardPass
      :: AstTensor AstMethodShare PrimalSpan (X src)
      -> AstVarName FullSpan (X src)
      -> AstTensor AstMethodLet FullSpan (X src)
      -> ADVal (AstRaw PrimalSpan) ztgt
    forwardPass = forwardPassByInterpretation wrapped emptyEnv
-- | Produce a reverse-mode artifact, together with its delta expression,
-- directly from a function on dual numbers.
--
-- The AST is not interpreted: the forward pass simply applies @f@ (see
-- 'forwardPassByApplication').
revProduceArtifactWithoutInterpretation
  :: forall x z.
     IncomingCotangentHandling
  -> (ADVal (AstRaw PrimalSpan) x -> ADVal (AstRaw PrimalSpan) z)
  -> FullShapeTK x
  -> (AstArtifactRev x z, Delta (AstRaw PrimalSpan) z)
{-# INLINE revProduceArtifactWithoutInterpretation #-}
revProduceArtifactWithoutInterpretation cotangentHandling f xftk =
  revArtifactFromForwardPass cotangentHandling forwardPass xftk
  where
    forwardPass = forwardPassByApplication f
-- | Forward pass that applies @g@ directly to a dual number.
--
-- The primal part of the dual number is the primal AST variable wrapped in
-- 'AstRaw'. Its delta part holds the delta inputs generated from the full
-- shape of @var@. The full-span AST variable (last argument) is ignored.
forwardPassByApplication
  :: forall x z.
     (ADVal (AstRaw PrimalSpan) x -> ADVal (AstRaw PrimalSpan) z)
  -> AstTensor AstMethodShare PrimalSpan x
  -> AstVarName FullSpan x
  -> AstTensor AstMethodLet FullSpan x
  -> ADVal (AstRaw PrimalSpan) z
{-# INLINE forwardPassByApplication #-}
forwardPassByApplication g astVarPrimal var _astVar =
  g $ dDnotShared (AstRaw astVarPrimal)
                  (generateDeltaInputs (varNameToFTK var))
-- | Forward-mode derivative (Jacobian-vector product) of @f@ at @vals0@
-- in the direction @ds@, computed via the symbolic (AST) pipeline.
--
-- Steps:
--
-- 1. Build a forward-derivative artifact with 'fwdArtifactAdapt', using
--    the full shape of @vals0@.
-- 2. Simplify it with 'simplifyArtifactDerivative'.
-- 3. Interpret it on the concrete inputs with 'fwdInterpretArtifact'.
--
-- @ds@ is first converted with 'toADTensorKindShared'. Only the derivative
-- component is returned; the primal value is discarded.
jvp
  :: forall src ztgt tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)
  -> Value src
  -> Value src
  -> Concrete (ADTensorKind ztgt)
{-# INLINE jvp #-}
jvp f vals0 ds =
  let valsTarget :: Concrete (X (Value src))
      valsTarget = toTarget vals0
      -- The domain shape is read off the concrete input values.
      xftk :: FullShapeTK (X src)
      xftk = tftkG (knownSTK @(X src)) $ unConcrete valsTarget
      artifactRaw :: AstArtifactFwd (X src) ztgt
      artifactRaw = fwdArtifactAdapt f xftk
      artifact :: AstArtifactFwd (X src) ztgt
      artifact = simplifyArtifactDerivative artifactRaw
  in fst $ fwdInterpretArtifact artifact valsTarget
         $ toADTensorKindShared xftk (toTarget ds)
-- | Produce the (unsimplified) forward-derivative artifact of @f@.
--
-- @vals0@ is used only to determine the full shape of the domain. Its
-- values do not otherwise enter the artifact. The artifact can later be
-- evaluated with 'jvpInterpretArtifact' or 'fwdInterpretArtifact'.
jvpArtifact
  :: forall src ztgt tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)
  -> Value src
  -> AstArtifactFwd (X src) ztgt
{-# INLINE jvpArtifact #-}
jvpArtifact f vals0 =
  let xftk :: FullShapeTK (X src)
      xftk = tftkG (knownSTK @(X src)) $ unConcrete $ toTarget vals0
  in fwdArtifactAdapt f xftk
-- | Evaluate a forward-derivative artifact at @parameters@ for a given
-- perturbation, returning only the derivative.
--
-- This is 'fwdInterpretArtifact' with the primal component dropped, so
-- the same shape checks (and 'error's) apply.
jvpInterpretArtifact
  :: forall x z.
     AstArtifactFwd x z
  -> Concrete x
  -> Concrete (ADTensorKind x)
  -> Concrete (ADTensorKind z)
{-# INLINE jvpInterpretArtifact #-}
jvpInterpretArtifact art parameters =
  fst . fwdInterpretArtifact art parameters
-- | Build a forward-derivative artifact for an objective whose input is
-- an adaptable structure @src@.
--
-- The adapted function @g@ does three things:
--
-- * takes a single AST tensor of the whole domain;
-- * let-binds it with 'ttlet', rebuilds @src@ with 'fromTarget' and
--   applies @f@;
-- * simplifies the resulting term with 'simplifyInline'.
--
-- @g@ is passed to 'fwdProduceArtifact' together with an empty
-- environment and the domain shape @xftk@.
fwdArtifactAdapt
  :: forall src ztgt tgt.
     ( AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)
  -> FullShapeTK (X src)
  -> AstArtifactFwd (X src) ztgt
{-# INLINE fwdArtifactAdapt #-}
fwdArtifactAdapt f xftk =
  let g :: AstTensor AstMethodLet FullSpan (X src) -> tgt
      -- The argument is forced to WHNF (bang pattern) before the term is
      -- built; let-binding it means @f@ sees a shared variable rather than
      -- a duplicated term.
      g !arg = simplifyInline $ ttlet arg $ f . fromTarget
  in fwdProduceArtifact g emptyEnv xftk
-- | Interpret a forward-derivative artifact on concrete inputs.
--
-- Two environments are used:
--
-- * @parameters@ is bound to the artifact's domain variable, and the
--   primal term is interpreted in that environment;
-- * @ds@ is additionally bound to the perturbation variable, and the
--   derivative term is interpreted in that larger environment.
--
-- Returns @(derivative, primal)@.
--
-- Calls 'error' in two cases:
--
-- * the shape of @parameters@ differs from the domain shape recorded in
--   the artifact's domain variable;
-- * the shape of @ds@ differs from the 'adFTK' of that domain shape.
fwdInterpretArtifact
  :: forall x z.
     AstArtifactFwd x z
  -> Concrete x
  -> Concrete (ADTensorKind x)
  -> (Concrete (ADTensorKind z), Concrete z)
{-# INLINE fwdInterpretArtifact #-}
fwdInterpretArtifact AstArtifactFwd{..} parameters ds
  | tftkG xstk (unConcrete parameters) /= xftk =
      error "fwdInterpretArtifact: forward derivative input must have the same shape as the domain of the objective function"
  | tftkG (adSTK xstk) (unConcrete ds) /= adFTK xftk =
      error "fwdInterpretArtifact: forward derivative perturbation must have the same shape as the domain of the objective function"
  | otherwise =
      let derivative :: PrimalOf Concrete (ADTensorKind z)
          derivative = interpretAstPrimal envD artDerivativeFwd
          primal :: PrimalOf Concrete z
          primal = interpretAstPrimal env artPrimalFwd
      in (derivative, primal)
  where
    -- The expected domain shape is the one stored in the domain variable.
    xftk :: FullShapeTK x
    xftk = varNameToFTK artVarDomainFwd
    xstk :: SingletonTK x
    xstk = ftkToSTK xftk
    -- Environment for the primal term: only the domain variable.
    env :: AstEnv Concrete
    env = extendEnv artVarDomainFwd parameters emptyEnv
    -- Environment for the derivative term: domain plus perturbation.
    envD :: AstEnv Concrete
    envD = extendEnv artVarDsFwd ds env
-- | Build a forward-derivative artifact together with the delta term
-- produced during the forward pass.
--
-- The objective is adapted as in 'fwdArtifactAdapt' (let-bind the
-- argument, then 'fromTarget' and @f@), except that 'simplifyInline' is
-- not applied. The adapted function is run through
-- 'forwardPassByInterpretation' with an empty environment and handed to
-- 'fwdArtifactFromForwardPass'.
fwdArtifactDelta
  :: forall src ztgt tgt.
     ( AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)
  -> FullShapeTK (X src)
  -> (AstArtifactFwd (X src) ztgt, Delta (AstRaw PrimalSpan) ztgt)
{-# INLINE fwdArtifactDelta #-}
fwdArtifactDelta f xftk =
  let g :: AstTensor AstMethodLet FullSpan (X src) -> tgt
      g !arg = ttlet arg $ f . fromTarget
  in fwdArtifactFromForwardPass (forwardPassByInterpretation g emptyEnv) xftk
-- | Gradient of a scalar-valued @f@ at @vals@, computed with the concrete
-- (non-symbolic) 'ADVal' pipeline.
--
-- Delegates to 'crevMaybe' with no explicit incoming cotangent
-- ('Nothing'). NOTE(review): presumably the cotangent then defaults to
-- one inside 'crevOnParams'; confirm there.
cgrad
  :: forall src r tgt.
     ( X src ~ X (DValue src), KnownSTK (X src)
     , AdaptableTarget (ADVal Concrete) src
     , AdaptableTarget Concrete (DValue src)
     , tgt ~ ADVal Concrete (TKScalar r) )
  => (src -> tgt)
  -> DValue src
  -> DValue src
{-# INLINE cgrad #-}
cgrad f vals = crevMaybe f vals Nothing
-- | Vector-Jacobian product of @f@ at @vals@ with the explicit cotangent
-- @dt@, computed with the concrete (non-symbolic) 'ADVal' pipeline.
--
-- Delegates to 'crevMaybe' with @Just dt@.
cvjp
  :: forall src ztgt tgt.
     ( X src ~ X (DValue src), KnownSTK (X src)
     , AdaptableTarget (ADVal Concrete) src
     , AdaptableTarget Concrete (DValue src)
     , tgt ~ ADVal Concrete ztgt )
  => (src -> tgt)
  -> DValue src
  -> Concrete (ADTensorKind ztgt)
  -> DValue src
{-# INLINE cvjp #-}
cvjp f vals dt = crevMaybe f vals (Just dt)
-- | Shared implementation of 'cgrad' and 'cvjp' (concrete reverse mode).
--
-- Steps:
--
-- 1. Convert @vals0@ to a 'Concrete' tensor and read off the domain shape.
-- 2. Run 'crevOnParams' with the optional cotangent @mdt@ and
--    @f . fromTarget@.
-- 3. Keep the first component of the result (the one typed over the
--    domain, i.e. the gradient).
-- 4. Convert that component back from its 'ADTensorKind' with
--    'fromADTensorKindShared'.
-- 5. Rebuild the user-level structure with 'fromTarget'.
crevMaybe
  :: forall src ztgt tgt.
     ( X src ~ X (DValue src), KnownSTK (X src)
     , AdaptableTarget (ADVal Concrete) src
     , AdaptableTarget Concrete (DValue src)
     , tgt ~ ADVal Concrete ztgt )
  => (src -> tgt)
  -> DValue src
  -> Maybe (Concrete (ADTensorKind ztgt))
  -> DValue src
{-# INLINE crevMaybe #-}
crevMaybe f vals0 mdt =
  let valsTarget :: Concrete (X (DValue src))
      valsTarget = toTarget vals0
      g :: ADVal Concrete (X src) -> tgt
      g = f . fromTarget
      xftk :: FullShapeTK (X src)
      xftk = tftkG (knownSTK @(X src)) $ unConcrete valsTarget
  in fromTarget $ fromADTensorKindShared (ftkToSTK xftk)
     $ fst $ crevOnParams mdt g xftk valsTarget
-- | Forward-mode derivative (Jacobian-vector product) of @f@ at @vals@ in
-- the direction @ds@, computed with the concrete (non-symbolic) 'ADVal'
-- pipeline.
--
-- This is the derivative component of 'cfwdBoth'; the primal value is
-- discarded.
cjvp
  :: forall src ztgt tgt.
     ( X src ~ X (DValue src), KnownSTK (X src)
     , AdaptableTarget (ADVal Concrete) src
     , AdaptableTarget Concrete (DValue src)
     , tgt ~ ADVal Concrete ztgt )
  => (src -> tgt)
  -> DValue src
  -> DValue src
  -> Concrete (ADTensorKind ztgt)
{-# INLINE cjvp #-}
cjvp f vals ds = fst $ cfwdBoth f vals ds
-- | Evaluates @f@ at @vals0@ on 'ADVal' 'Concrete' dual numbers with
-- tangent @ds@, by delegating to 'cfwdOnParams'.
--
-- Both @vals0@ and @ds@ are first flattened with 'toTarget' into a single
-- 'Concrete' value of kind @X src@. @f@ is adapted with 'fromTarget' so
-- that it accepts that flattened dual-number representation.
--
-- The result is the pair returned by 'cfwdOnParams': a value of kind
-- @ADTensorKind ztgt@ (the derivative) and a value of kind @ztgt@
-- (presumably the primal result of @f vals0@; confirm in 'cfwdOnParams').
-- 'cjvp' keeps only the first component.
--
-- Calls 'error' if the full shape of @ds@ differs from the full shape of
-- @vals0@.
cfwdBoth
  :: forall src ztgt tgt.
     ( X src ~ X (DValue src), KnownSTK (X src)
     , AdaptableTarget (ADVal Concrete) src
     , AdaptableTarget Concrete (DValue src)
     , tgt ~ ADVal Concrete ztgt )
  => (src -> tgt)
  -> DValue src
  -> DValue src
  -> (Concrete (ADTensorKind ztgt), Concrete ztgt)
{-# INLINE cfwdBoth #-}
cfwdBoth f vals0 ds
  | perturbationFtk == paramFtk =
      cfwdOnParams paramFtk params liftedF
                   (toADTensorKindShared paramFtk perturbation)
  | otherwise =
      error "cfwdBoth: forward derivative input must have the same shape as the perturbation argument"
  where
    -- The evaluation point and the direction, both flattened to @X src@.
    params, perturbation :: Concrete (X (DValue src))
    params = toTarget vals0
    perturbation = toTarget ds
    -- Full shape of the evaluation point, read using the singleton of @X src@.
    paramFtk :: FullShapeTK (X src)
    paramFtk = tftkG (knownSTK @(X src)) (unConcrete params)
    -- Full shape of the direction, read using the singleton derived from
    -- 'paramFtk', so that the two shapes can be compared with '=='.
    perturbationFtk :: FullShapeTK (X src)
    perturbationFtk = tftkG (ftkToSTK paramFtk) (unConcrete perturbation)
    -- @f@ lifted to take the flattened dual-number argument.
    liftedF :: ADVal Concrete (X src) -> tgt
    liftedF = f . fromTarget
-- Type-specialisation requests for delta evaluation at the 'Concrete'
-- target. A SPECIALIZE pragma asks GHC for a copy of an overloaded function
-- instantiated at the given type, so that callers at that type avoid
-- dictionary passing.
-- NOTE(review): these functions do not appear to be defined in this module
-- (presumably they come from HordeAd.Core.DeltaEval). Specialising an
-- imported function requires its unfolding to be available (e.g.
-- INLINABLE at the definition site) — confirm there.
{-# SPECIALIZE gradientFromDelta :: FullShapeTK x -> FullShapeTK z -> Concrete (ADTensorKind z) -> Delta Concrete z -> Concrete (ADTensorKind x) #-}
{-# SPECIALIZE evalRev :: FullShapeTK y -> EvalState Concrete -> Concrete (ADTensorKind y) -> Delta Concrete y -> EvalState Concrete #-}
{-# SPECIALIZE evalRevFTK :: EvalState Concrete -> Concrete (ADTensorKind y) -> Delta Concrete y -> EvalState Concrete #-}
{-# SPECIALIZE evalRevFromnMap :: EvalState Concrete -> EvalState Concrete #-}
-- evalRevSame at the 'Concrete' target for TKScalar, TKR, TKS and TKX
-- tensor kinds, each with Double and Float elements.
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKScalar Double) -> Delta Concrete (TKScalar Double) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKScalar Float) -> Delta Concrete (TKScalar Float) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKR n Double) -> Delta Concrete (TKR n Double) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKR n Float) -> Delta Concrete (TKR n Float) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKS sh Double) -> Delta Concrete (TKS sh Double) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKS sh Float) -> Delta Concrete (TKS sh Float) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKX sh Double) -> Delta Concrete (TKX sh Double) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKX sh Float) -> Delta Concrete (TKX sh Float) -> EvalState Concrete #-}
-- interpretAst specialised for every (span, environment) combination:
-- spans PrimalSpan, DualSpan and FullSpan; environments over
-- ADVal Concrete, ADVal (AstRaw PrimalSpan) and Concrete. The result is in
-- the same target as the environment.
{-# SPECIALIZE interpretAst
:: AstEnv (ADVal Concrete)
-> AstTensor AstMethodLet PrimalSpan y
-> ADVal Concrete y #-}
{-# SPECIALIZE interpretAst
:: AstEnv (ADVal (AstRaw PrimalSpan))
-> AstTensor AstMethodLet PrimalSpan y
-> ADVal (AstRaw PrimalSpan) y #-}
{-# SPECIALIZE interpretAst
:: AstEnv Concrete
-> AstTensor AstMethodLet PrimalSpan y
-> Concrete y #-}
{-# SPECIALIZE interpretAst
:: AstEnv (ADVal Concrete)
-> AstTensor AstMethodLet DualSpan y
-> ADVal Concrete y #-}
{-# SPECIALIZE interpretAst
:: AstEnv (ADVal (AstRaw PrimalSpan))
-> AstTensor AstMethodLet DualSpan y
-> ADVal (AstRaw PrimalSpan) y #-}
{-# SPECIALIZE interpretAst
:: AstEnv Concrete
-> AstTensor AstMethodLet DualSpan y
-> Concrete y #-}
{-# SPECIALIZE interpretAst
:: AstEnv (ADVal Concrete)
-> AstTensor AstMethodLet FullSpan y
-> ADVal Concrete y #-}
{-# SPECIALIZE interpretAst
:: AstEnv (ADVal (AstRaw PrimalSpan))
-> AstTensor AstMethodLet FullSpan y
-> ADVal (AstRaw PrimalSpan) y #-}
{-# SPECIALIZE interpretAst
:: AstEnv Concrete
-> AstTensor AstMethodLet FullSpan y
-> Concrete y #-}
-- interpretAstPrimal on PrimalSpan terms, for the same three environments.
-- The results are Concrete, AstRaw PrimalSpan and Concrete respectively
-- (no ADVal wrapper).
{-# SPECIALIZE interpretAstPrimal
:: AstEnv (ADVal Concrete)
-> AstTensor AstMethodLet PrimalSpan y
-> Concrete y #-}
{-# SPECIALIZE interpretAstPrimal
:: AstEnv (ADVal (AstRaw PrimalSpan))
-> AstTensor AstMethodLet PrimalSpan y
-> AstRaw PrimalSpan y #-}
{-# SPECIALIZE interpretAstPrimal
:: AstEnv Concrete
-> AstTensor AstMethodLet PrimalSpan y
-> Concrete y #-}
-- interpretAstBool for the same three environments. The Concrete-based
-- environments yield a plain Bool; the AstRaw-based one yields
-- AstBool AstMethodShare.
{-# SPECIALIZE interpretAstBool
:: AstEnv (ADVal Concrete)
-> AstBool AstMethodLet
-> Bool #-}
{-# SPECIALIZE interpretAstBool
:: AstEnv (ADVal (AstRaw PrimalSpan))
-> AstBool AstMethodLet
-> AstBool AstMethodShare #-}
{-# SPECIALIZE interpretAstBool
:: AstEnv Concrete
-> AstBool AstMethodLet
-> Bool #-}