Safe Haskell | None |
---|---|
Language | GHC2024 |
HordeAd.ADEngine
Description
The implementation of reverse derivative and forward derivative calculation for an objective function on values of complicated types, e.g., nested tuples of tensors.
The objective function can be defined as a sufficiently polymorphic
Haskell function that uses numeric classes as well as the multi-dimensional
tensor operations listed in HordeAd.OpsTensor. To obtain symbolic
derivatives (derivative code that can be executed many times without
performing AD again), the user needs an objective function polymorphic
enough so that it can be instantiated to the AstTensor
type (nested in tuples, etc., for some extra flexibility).
For non-symbolic derivatives, the ability to instantiate to the ADVal
type of dual numbers is enough.
See the classes these types are instances of to gauge the breadth
of the offered respective APIs.
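The idea of a "sufficiently polymorphic" objective function can be illustrated without horde-ad at all. The sketch below is a toy model, assuming a hand-written `Dual` type as a stand-in for dual numbers; `Dual`, `objective`, `valueAt` and `derivativeAt` are hypothetical names introduced here and are not part of the library. The same function, written once against the `Num` class, is instantiated both to plain `Double` and to dual numbers.

```haskell
-- Toy dual numbers: a primal value paired with a tangent component.
-- This is an illustrative stand-in, not horde-ad's ADVal type.
data Dual = Dual Double Double
  deriving (Show, Eq)

instance Num Dual where
  Dual a a' + Dual b b' = Dual (a + b) (a' + b')
  Dual a a' - Dual b b' = Dual (a - b) (a' - b')
  Dual a a' * Dual b b' = Dual (a * b) (a * b' + a' * b)
  negate (Dual a a')    = Dual (negate a) (negate a')
  abs (Dual a a')       = Dual (abs a) (a' * signum a)
  signum (Dual a _)     = Dual (signum a) 0
  fromInteger n         = Dual (fromInteger n) 0

-- A hypothetical objective function, polymorphic in the numeric type.
objective :: Num a => a -> a
objective x = x * x * x + 2 * x

-- Instantiated at Double, the function only computes its value.
valueAt :: Double -> Double
valueAt = objective

-- Instantiated at Dual with tangent 1, it also yields the derivative.
derivativeAt :: Double -> Double
derivativeAt x = let Dual _ d = objective (Dual x 1) in d
```

Because `objective` never mentions a concrete number type, nothing in its definition has to change when a different interpretation (values, dual numbers, or syntax trees) is needed.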
Synopsis
- grad :: forall src r tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ('TKScalar r)) => (src -> tgt) -> Value src -> Value src
- vjp :: forall src (ztgt :: TK) tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) => (src -> tgt) -> Value src -> Concrete (ADTensorKind ztgt) -> Value src
- gradArtifact :: forall src r tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ('TKScalar r)) => (src -> tgt) -> Value src -> AstArtifactRev (X src) ('TKScalar r)
- vjpArtifact :: forall src (ztgt :: TK) tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) => (src -> tgt) -> Value src -> AstArtifactRev (X src) ztgt
- gradInterpretArtifact :: forall (x :: TK) r avals. (X avals ~ ADTensorKind x, AdaptableTarget Concrete avals) => AstArtifactRev x ('TKScalar r) -> Concrete x -> avals
- vjpInterpretArtifact :: forall (x :: TK) (z :: TK) avals. (X avals ~ ADTensorKind x, AdaptableTarget Concrete avals) => AstArtifactRev x z -> Concrete x -> Concrete (ADTensorKind z) -> avals
- jvp :: forall src (ztgt :: TK) tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) => (src -> tgt) -> Value src -> Value src -> Concrete (ADTensorKind ztgt)
- jvpArtifact :: forall src (ztgt :: TK) tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) => (src -> tgt) -> Value src -> AstArtifactFwd (X src) ztgt
- jvpInterpretArtifact :: forall (x :: TK) (z :: TK). AstArtifactFwd x z -> Concrete x -> Concrete (ADTensorKind x) -> Concrete (ADTensorKind z)
- cgrad :: forall src r tgt. (X src ~ X (DValue src), KnownSTK (X src), AdaptableTarget (ADVal Concrete) src, AdaptableTarget Concrete (DValue src), tgt ~ ADVal Concrete ('TKScalar r)) => (src -> tgt) -> DValue src -> DValue src
- cvjp :: forall src (ztgt :: TK) tgt. (X src ~ X (DValue src), KnownSTK (X src), AdaptableTarget (ADVal Concrete) src, AdaptableTarget Concrete (DValue src), tgt ~ ADVal Concrete ztgt) => (src -> tgt) -> DValue src -> Concrete (ADTensorKind ztgt) -> DValue src
- cjvp :: forall src (ztgt :: TK) tgt. (X src ~ X (DValue src), KnownSTK (X src), AdaptableTarget (ADVal Concrete) src, AdaptableTarget Concrete (DValue src), tgt ~ ADVal Concrete ztgt) => (src -> tgt) -> DValue src -> DValue src -> Concrete (ADTensorKind ztgt)
- data IncomingCotangentHandling
- revArtifactAdapt :: forall src (ztgt :: TK) tgt. (AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) => IncomingCotangentHandling -> (src -> tgt) -> FullShapeTK (X src) -> AstArtifactRev (X src) ztgt
- revArtifactDelta :: forall src (ztgt :: TK) tgt. (AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) => IncomingCotangentHandling -> (src -> tgt) -> FullShapeTK (X src) -> (AstArtifactRev (X src) ztgt, Delta (AstRaw 'PrimalSpan) ztgt)
- revProduceArtifactWithoutInterpretation :: forall (x :: TK) (z :: TK). IncomingCotangentHandling -> (ADVal (AstRaw 'PrimalSpan) x -> ADVal (AstRaw 'PrimalSpan) z) -> FullShapeTK x -> (AstArtifactRev x z, Delta (AstRaw 'PrimalSpan) z)
- revInterpretArtifact :: forall (x :: TK) (z :: TK). AstArtifactRev x z -> Concrete x -> Maybe (Concrete (ADTensorKind z)) -> (Concrete (ADTensorKind x), Concrete z)
- fwdArtifactAdapt :: forall src (ztgt :: TK) tgt. (AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) => (src -> tgt) -> FullShapeTK (X src) -> AstArtifactFwd (X src) ztgt
- fwdArtifactDelta :: forall src (ztgt :: TK) tgt. (AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) => (src -> tgt) -> FullShapeTK (X src) -> (AstArtifactFwd (X src) ztgt, Delta (AstRaw 'PrimalSpan) ztgt)
- fwdInterpretArtifact :: forall (x :: TK) (z :: TK). AstArtifactFwd x z -> Concrete x -> Concrete (ADTensorKind x) -> (Concrete (ADTensorKind z), Concrete z)
- cfwdBoth :: forall src (ztgt :: TK) tgt. (X src ~ X (DValue src), KnownSTK (X src), AdaptableTarget (ADVal Concrete) src, AdaptableTarget Concrete (DValue src), tgt ~ ADVal Concrete ztgt) => (src -> tgt) -> DValue src -> DValue src -> (Concrete (ADTensorKind ztgt), Concrete ztgt)
Symbolic reverse derivative adaptors
grad
Arguments
:: forall src r tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ('TKScalar r)) | |
=> (src -> tgt) | the objective function |
-> Value src | |
-> Value src |
This simplified version of the symbolic reverse derivative operation
sets the incoming cotangent dt
to be 1 and assumes the codomain
of the function to be differentiated is a scalar.
We don't enforce (e.g., by quantification) that the objective function is closed, because we evaluate the result of the differentiation down to concrete arrays and so there's no risk of "perturbation confusion" between different levels of differentiation if it's done multiple times. For simplicity of the type signature, the resulting value is converted from the type of concrete cotangents to the type of concrete input parameters.
vjp
Arguments
:: forall src (ztgt :: TK) tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) | |
=> (src -> tgt) | the objective function |
-> Value src | |
-> Concrete (ADTensorKind ztgt) | |
-> Value src |
This version of the symbolic reverse derivative operation explicitly takes the sensitivity parameter (the incoming cotangent). It also permits an arbitrary (nested tuple+) type of the domain and arbitrary (nested pair) tensor kind of the codomain of the function to be differentiated. The downside of the generality is that if the function doesn't have an explicit type signature, the type to which this operation is instantiated often has to be spelled in full via explicit type applications to aid type reconstruction. For simplicity of the type signature, the resulting value is converted from the type of concrete cotangents to the type of concrete input parameters.
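In conventional mathematical notation, the relationship between the general operation and the simplified scalar one can be summarized as follows. This is the standard reading of a vector-Jacobian product, not a statement about the library's internals; the symbols $J_f$, $X$ and $Z$ are notation introduced here for illustration.

```latex
% f : X \to Z is the objective function, J_f(x) its Jacobian at x,
% and dt the incoming cotangent, which has the shape of the output.
\begin{align*}
\mathrm{vjp}\; f\; x\; dt &= J_f(x)^{\top}\, dt \\
\mathrm{grad}\; f\; x &= \mathrm{vjp}\; f\; x\; 1 = \nabla f(x)
  \qquad \text{when } Z = \mathbb{R}
\end{align*}
```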
gradArtifact
Arguments
:: forall src r tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ('TKScalar r)) | |
=> (src -> tgt) | the objective function |
-> Value src | |
-> AstArtifactRev (X src) ('TKScalar r) | the artifact containing the symbolic code of the derivative |
Compute the reverse derivative not for a specific input, but as a symbolic function from inputs to the gradient value. The function is represented as an "artifact", which is the gradient AST term together with the variable corresponding to the input.
vjpArtifact
Arguments
:: forall src (ztgt :: TK) tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) | |
=> (src -> tgt) | the objective function |
-> Value src | |
-> AstArtifactRev (X src) ztgt | the artifact containing the symbolic code of the derivative |
Compute the reverse derivative not for a specific input, but as a symbolic function from inputs and incoming cotangents to the gradient value. The function is represented as an "artifact", which is the gradient AST term together with variables corresponding to the input and cotangent.
gradInterpretArtifact
Arguments
:: forall (x :: TK) r avals. (X avals ~ ADTensorKind x, AdaptableTarget Concrete avals) | |
=> AstArtifactRev x ('TKScalar r) | the artifact containing the symbolic code of the derivative |
-> Concrete x | |
-> avals |
Interpret the "artifact" as a function from a concrete tensor to a concrete tensor (possibly adapted, e.g., from horde-ad nested pairs to Haskell n-tuples).
vjpInterpretArtifact
Arguments
:: forall (x :: TK) (z :: TK) avals. (X avals ~ ADTensorKind x, AdaptableTarget Concrete avals) | |
=> AstArtifactRev x z | the artifact containing the symbolic code of the derivative |
-> Concrete x | |
-> Concrete (ADTensorKind z) | |
-> avals |
Interpret the "artifact" as a function from concrete tensors to a concrete tensor (possibly adapted, e.g., from horde-ad nested pairs to Haskell n-tuples).
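The "build the derivative once, interpret it many times" workflow can be sketched with a toy expression language. This is only an analogy: it swaps in naive symbolic differentiation of a one-variable AST for the AD that produces a real artifact, and `Expr`, `ToyArtifact`, `toyGradArtifact` and `toyGradInterpret` are hypothetical names, not horde-ad's API.

```haskell
-- A toy expression language over a single input variable.
-- It is an illustrative stand-in, not horde-ad's AstTensor.
data Expr
  = Var
  | Const Double
  | Add Expr Expr
  | Mul Expr Expr
  deriving (Show, Eq)

-- The Num instance lets a polymorphic objective function build an Expr.
instance Num Expr where
  (+) = Add
  (*) = Mul
  negate e = Mul (Const (-1)) e
  abs = error "abs: not supported in this sketch"
  signum = error "signum: not supported in this sketch"
  fromInteger = Const . fromInteger

-- Naive symbolic differentiation with respect to Var.
diff :: Expr -> Expr
diff Var = Const 1
diff (Const _) = Const 0
diff (Add a b) = Add (diff a) (diff b)
diff (Mul a b) = Add (Mul (diff a) b) (Mul a (diff b))

-- Evaluate an expression at a concrete input.
eval :: Expr -> Double -> Double
eval Var x = x
eval (Const c) _ = c
eval (Add a b) x = eval a x + eval b x
eval (Mul a b) x = eval a x * eval b x

-- Hypothetical artifact: the derivative term, with Var as its input variable.
newtype ToyArtifact = ToyArtifact Expr

-- Differentiate once, producing reusable derivative code.
toyGradArtifact :: (Expr -> Expr) -> ToyArtifact
toyGradArtifact f = ToyArtifact (diff (f Var))

-- Run the derivative code on a concrete input; no differentiation happens here.
toyGradInterpret :: ToyArtifact -> Double -> Double
toyGradInterpret (ToyArtifact e) = eval e

objective :: Num a => a -> a
objective x = x * x + 3 * x

artifact :: ToyArtifact
artifact = toyGradArtifact objective

-- The same artifact is interpreted at several inputs.
gradients :: [Double]
gradients = map (toyGradInterpret artifact) [0, 1, 2]
```

The point of the split is cost: differentiation happens once in `toyGradArtifact`, while each later call to `toyGradInterpret` is plain evaluation.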
Symbolic forward derivative adaptors
jvp
Arguments
:: forall src (ztgt :: TK) tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) | |
=> (src -> tgt) | the objective function |
-> Value src | |
-> Value src | |
-> Concrete (ADTensorKind ztgt) |
The forward derivative operation takes the perturbation parameter by convention. It permits an arbitrary (nested tuple+) type of the domain and arbitrary (nested pair) tensor kind of the codomain of the function to be differentiated. The generality sometimes makes it necessary to supply type hints when applying this operation.
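The role of the perturbation argument can be seen in a small self-contained model of forward-mode differentiation over a pair of scalars. This is a sketch under the assumption that dual numbers are an adequate mental model; `Dual` and `toyJvp` are hypothetical names and do not reproduce the library's types or constraints.

```haskell
-- Toy dual numbers: a primal value paired with a tangent component.
data Dual = Dual Double Double

instance Num Dual where
  Dual a a' + Dual b b' = Dual (a + b) (a' + b')
  Dual a a' - Dual b b' = Dual (a - b) (a' - b')
  Dual a a' * Dual b b' = Dual (a * b) (a * b' + a' * b)
  negate (Dual a a')    = Dual (negate a) (negate a')
  abs (Dual a a')       = Dual (abs a) (a' * signum a)
  signum (Dual a _)     = Dual (signum a) 0
  fromInteger n         = Dual (fromInteger n) 0

-- Toy counterpart of a Jacobian-vector product for functions of two scalars:
-- the inputs (x, y) are seeded with the perturbation (dx, dy) and the
-- tangent of the result is returned.
toyJvp :: ((Dual, Dual) -> Dual) -> (Double, Double) -> (Double, Double) -> Double
toyJvp f (x, y) (dx, dy) =
  let Dual _ dz = f (Dual x dx, Dual y dy) in dz

-- A hypothetical objective function on a pair.
objective :: Num a => (a, a) -> a
objective (x, y) = x * y + x

-- Directional derivatives at (3, 4) along three perturbations.
alongX, alongY, alongBoth :: Double
alongX    = toyJvp objective (3, 4) (1, 0)
alongY    = toyJvp objective (3, 4) (0, 1)
alongBoth = toyJvp objective (3, 4) (1, 1)
```

A unit perturbation along one input recovers the corresponding partial derivative, and a general perturbation gives the matching linear combination of partial derivatives.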
jvpArtifact
Arguments
:: forall src (ztgt :: TK) tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) | |
=> (src -> tgt) | the objective function |
-> Value src | |
-> AstArtifactFwd (X src) ztgt | the artifact containing the symbolic code of the derivative |
Compute the forward derivative not for a specific input, but as a symbolic function from inputs and perturbation to the derivative value. The function is represented as an "artifact", which is the derivative AST term together with variables corresponding to the input and perturbation.
jvpInterpretArtifact
Arguments
:: forall (x :: TK) (z :: TK). AstArtifactFwd x z | the artifact containing the symbolic code of the derivative |
-> Concrete x | |
-> Concrete (ADTensorKind x) | |
-> Concrete (ADTensorKind z) |
Interpret the "artifact" as a function from concrete tensors to a concrete tensor.
Non-symbolic reverse derivative adaptors
cgrad
Arguments
:: forall src r tgt. (X src ~ X (DValue src), KnownSTK (X src), AdaptableTarget (ADVal Concrete) src, AdaptableTarget Concrete (DValue src), tgt ~ ADVal Concrete ('TKScalar r)) | |
=> (src -> tgt) | the objective function |
-> DValue src | |
-> DValue src |
This simplified version of the concrete (non-symbolic)
reverse derivative operation sets the incoming cotangent dt to be 1
and assumes the codomain of the function to be differentiated is a scalar.
cvjp
Arguments
:: forall src (ztgt :: TK) tgt. (X src ~ X (DValue src), KnownSTK (X src), AdaptableTarget (ADVal Concrete) src, AdaptableTarget Concrete (DValue src), tgt ~ ADVal Concrete ztgt) | |
=> (src -> tgt) | the objective function |
-> DValue src | |
-> Concrete (ADTensorKind ztgt) | |
-> DValue src |
This more general version of the concrete (non-symbolic) reverse derivative operation additionally takes the sensitivity parameter (the incoming cotangent).
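The difference between fixing the incoming cotangent to 1 and passing it explicitly can be sketched with a toy reverse-mode model based on backpropagator closures. This is not how the library implements the operation (the code below does no sharing and is inefficient for large programs); `Rev`, `toyCvjp` and `toyCgrad` are hypothetical names introduced for illustration.

```haskell
-- A value paired with a backpropagator: given the cotangent of this value,
-- it returns the contributions to the cotangents of the two inputs (x, y).
data Rev = Rev Double (Double -> (Double, Double))

-- Componentwise addition of cotangent contributions.
addPair :: (Double, Double) -> (Double, Double) -> (Double, Double)
addPair (a, b) (c, d) = (a + c, b + d)

instance Num Rev where
  Rev a f + Rev b g = Rev (a + b) (\d -> f d `addPair` g d)
  Rev a f * Rev b g = Rev (a * b) (\d -> f (d * b) `addPair` g (d * a))
  negate (Rev a f)  = Rev (negate a) (\d -> f (negate d))
  abs (Rev a f)     = Rev (abs a) (\d -> f (d * signum a))
  signum (Rev a _)  = Rev (signum a) (const (0, 0))
  fromInteger n     = Rev (fromInteger n) (const (0, 0))

-- Toy reverse derivative with an explicit incoming cotangent dt.
toyCvjp :: ((Rev, Rev) -> Rev) -> (Double, Double) -> Double -> (Double, Double)
toyCvjp f (x, y) dt =
  let Rev _ backprop = f (Rev x (\d -> (d, 0)), Rev y (\d -> (0, d)))
  in backprop dt

-- Simplified version: the incoming cotangent is fixed to 1.
toyCgrad :: ((Rev, Rev) -> Rev) -> (Double, Double) -> (Double, Double)
toyCgrad f xy = toyCvjp f xy 1

-- A hypothetical scalar-valued objective function on a pair.
objective :: Num a => (a, a) -> a
objective (x, y) = x * y + x

gradientAt34, scaledAt34 :: (Double, Double)
gradientAt34 = toyCgrad objective (3, 4)
scaledAt34   = toyCvjp objective (3, 4) 2
```

For a scalar codomain the explicit cotangent merely scales the gradient; it becomes essential once the codomain has more than one component.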
Non-symbolic forward derivative adaptors
cjvp
Arguments
:: forall src (ztgt :: TK) tgt. (X src ~ X (DValue src), KnownSTK (X src), AdaptableTarget (ADVal Concrete) src, AdaptableTarget Concrete (DValue src), tgt ~ ADVal Concrete ztgt) | |
=> (src -> tgt) | the objective function |
-> DValue src | |
-> DValue src | |
-> Concrete (ADTensorKind ztgt) |
Concrete (non-symbolic) forward derivative operation. It always takes the perturbation parameter, by convention.
Internal machinery for symbolic adaptors
data IncomingCotangentHandling
Constructors
UseIncomingCotangent | |
IgnoreIncomingCotangent |
revArtifactAdapt
Arguments
:: forall src (ztgt :: TK) tgt. (AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) | |
=> IncomingCotangentHandling | |
-> (src -> tgt) | the objective function |
-> FullShapeTK (X src) | |
-> AstArtifactRev (X src) ztgt | the artifact containing the symbolic code of the derivative |
revArtifactDelta
Arguments
:: forall src (ztgt :: TK) tgt. (AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) | |
=> IncomingCotangentHandling | |
-> (src -> tgt) | the objective function |
-> FullShapeTK (X src) | |
-> (AstArtifactRev (X src) ztgt, Delta (AstRaw 'PrimalSpan) ztgt) | the artifact containing the symbolic code of the derivative |
revProduceArtifactWithoutInterpretation
Arguments
:: forall (x :: TK) (z :: TK). IncomingCotangentHandling | |
-> (ADVal (AstRaw 'PrimalSpan) x -> ADVal (AstRaw 'PrimalSpan) z) | |
-> FullShapeTK x | |
-> (AstArtifactRev x z, Delta (AstRaw 'PrimalSpan) z) | the artifact containing the symbolic code of the derivative |
revInterpretArtifact
Arguments
:: forall (x :: TK) (z :: TK). AstArtifactRev x z | the artifact containing the symbolic code of the derivative |
-> Concrete x | |
-> Maybe (Concrete (ADTensorKind z)) | |
-> (Concrete (ADTensorKind x), Concrete z) |
fwdArtifactAdapt
Arguments
:: forall src (ztgt :: TK) tgt. (AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) | |
=> (src -> tgt) | the objective function |
-> FullShapeTK (X src) | |
-> AstArtifactFwd (X src) ztgt | the artifact containing the symbolic code of the derivative |
fwdArtifactDelta
Arguments
:: forall src (ztgt :: TK) tgt. (AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) | |
=> (src -> tgt) | the objective function |
-> FullShapeTK (X src) | |
-> (AstArtifactFwd (X src) ztgt, Delta (AstRaw 'PrimalSpan) ztgt) | the artifact containing the symbolic code of the derivative |
fwdInterpretArtifact
Arguments
:: forall (x :: TK) (z :: TK). AstArtifactFwd x z | the artifact containing the symbolic code of the derivative |
-> Concrete x | |
-> Concrete (ADTensorKind x) | |
-> (Concrete (ADTensorKind z), Concrete z) |
Internal machinery for non-symbolic adaptors
cfwdBoth
Arguments
:: forall src (ztgt :: TK) tgt. (X src ~ X (DValue src), KnownSTK (X src), AdaptableTarget (ADVal Concrete) src, AdaptableTarget Concrete (DValue src), tgt ~ ADVal Concrete ztgt) | |
=> (src -> tgt) | the objective function |
-> DValue src | |
-> DValue src | |
-> (Concrete (ADTensorKind ztgt), Concrete ztgt) |