horde-ad-0.1.0.0: Higher Order Reverse Derivatives Efficiently - Automatic Differentiation
Safe Haskell: None
Language: GHC2024

HordeAd.ADEngine

Description

The implementation of reverse derivative and forward derivative calculation for an objective function on values of complicated types, e.g., nested tuples of tensors.

The objective function can be defined as a sufficiently polymorphic Haskell function that uses numeric classes as well as the multi-dimensional tensor operations listed in HordeAd.OpsTensor. To obtain symbolic derivatives (derivative code that can be executed many times without performing AD again), the objective function must be polymorphic enough to be instantiated to the AstTensor type (nested in tuples, etc., for extra flexibility). For non-symbolic derivatives, it suffices that it can be instantiated to the ADVal type of dual numbers. To gauge the breadth of the respective APIs, see the classes these types are instances of.
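As a toy illustration of what "sufficiently polymorphic" buys, the Haskell sketch below defines an objective using only the `Num` class and instantiates it both at `Double` and at a hypothetical forward-mode dual-number type `Dual`. `Dual`, `objective`, `evalPlain` and `dObjectiveDx` are illustrative names only; they are not horde-ad's `ADVal` type or any part of its API.

```haskell
-- Toy forward-mode dual number (hypothetical stand-in for ADVal):
-- a primal value paired with a tangent.
data Dual = Dual Double Double
  deriving Show

instance Num Dual where
  Dual a a' + Dual b b' = Dual (a + b) (a' + b')
  Dual a a' - Dual b b' = Dual (a - b) (a' - b')
  Dual a a' * Dual b b' = Dual (a * b) (a * b' + a' * b)
  negate (Dual a a') = Dual (negate a) (negate a')
  abs (Dual a a') = Dual (abs a) (signum a * a')
  signum (Dual a _) = Dual (signum a) 0
  fromInteger n = Dual (fromInteger n) 0

-- Uses only the Num interface, so it can be instantiated at any Num type.
objective :: Num a => (a, a) -> a
objective (x, y) = x * x * y + 3 * y

-- Instantiated at Double: plain evaluation.
evalPlain :: (Double, Double) -> Double
evalPlain = objective

-- Instantiated at Dual: seeding x's tangent with 1 and y's with 0
-- yields the partial derivative with respect to x.
dObjectiveDx :: (Double, Double) -> Double
dObjectiveDx (x, y) = tangent (objective (Dual x 1, Dual y 0))
  where
    tangent (Dual _ t) = t
```

For instance, `evalPlain (2, 5)` is `35` and `dObjectiveDx (2, 5)` is `20`, i.e., 2xy at that point.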


Symbolic reverse derivative adaptors

grad

Arguments

:: forall src r tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ('TKScalar r)) 
=> (src -> tgt)

the objective function

-> Value src 
-> Value src 

This simplified version of the symbolic reverse derivative operation sets the incoming cotangent dt to be 1 and assumes the codomain of the function to be differentiated is a scalar.

We don't enforce (e.g., by quantification) that the objective function is closed: the result of the differentiation is evaluated down to concrete arrays, so there is no risk of "perturbation confusion" between different levels of differentiation if it is performed multiple times. For simplicity of the type signature, the resulting value is converted from the type of concrete cotangents to the type of concrete input parameters.
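The relationship between this simplified operation and the one taking an explicit cotangent can be sketched with a toy reverse-mode number `Rev`, whose backpropagator maps an incoming cotangent to per-input contributions. All names here (`Rev`, `vjpScalar`, `gradToy`, `objective`) are hypothetical, and the code is not horde-ad's implementation; it only shows that, for a scalar codomain, the gradient is the reverse derivative with dt fixed to 1.

```haskell
-- Toy reverse-mode value (hypothetical; not horde-ad's Delta machinery):
-- a primal value plus a backpropagator mapping an incoming cotangent to
-- (input index, contribution) pairs. Unlike a real engine, this toy
-- does not preserve sharing, so it may redo work for shared subterms.
data Rev = Rev Double (Double -> [(Int, Double)])

instance Num Rev where
  Rev a da + Rev b db = Rev (a + b) (\c -> da c ++ db c)
  Rev a da - Rev b db = Rev (a - b) (\c -> da c ++ db (negate c))
  Rev a da * Rev b db = Rev (a * b) (\c -> da (c * b) ++ db (c * a))
  negate (Rev a da) = Rev (negate a) (da . negate)
  abs (Rev a da) = Rev (abs a) (\c -> da (c * signum a))
  signum (Rev a _) = Rev (signum a) (const [])
  fromInteger n = Rev (fromInteger n) (const [])

-- Sum the contributions for each of the n inputs.
accumulate :: Int -> [(Int, Double)] -> [Double]
accumulate n contribs =
  [sum [v | (j, v) <- contribs, j == i] | i <- [0 .. n - 1]]

-- Reverse derivative of a scalar-valued function with an explicit
-- incoming cotangent dt.
vjpScalar :: ([Rev] -> Rev) -> [Double] -> Double -> [Double]
vjpScalar f xs dt = accumulate (length xs) (back dt)
  where
    inputs = [Rev x (\c -> [(i, c)]) | (i, x) <- zip [0 ..] xs]
    Rev _ back = f inputs

-- The simplified gradient fixes dt = 1 (meaningful for scalar codomains).
gradToy :: ([Rev] -> Rev) -> [Double] -> [Double]
gradToy f xs = vjpScalar f xs 1

-- Example objective: f(x, y) = x^2 * y + 3 * y.
objective :: Num a => [a] -> a
objective [x, y] = x * x * y + 3 * y
objective _ = 0
```

With this sketch, `gradToy objective [2, 5]` evaluates to `[20.0, 7.0]`, and scaling the cotangent scales the result: `vjpScalar objective [2, 5] 2` evaluates to `[40.0, 14.0]`.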

vjp

Arguments

:: forall src (ztgt :: TK) tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) 
=> (src -> tgt)

the objective function

-> Value src 
-> Concrete (ADTensorKind ztgt) 
-> Value src 

This version of the symbolic reverse derivative operation explicitly takes the sensitivity parameter (the incoming cotangent). It also permits an arbitrary (nested tuple+) type for the domain and an arbitrary (nested pair) tensor kind for the codomain of the function to be differentiated. The downside of this generality is that, if the function has no explicit type signature, the type at which this operation is instantiated often has to be spelled out in full via explicit type applications to aid type reconstruction. For simplicity of the type signature, the resulting value is converted from the type of concrete cotangents to the type of concrete input parameters.
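To illustrate the role of the sensitivity parameter when the codomain is not a scalar, the toy below differentiates a function returning a pair, so the cotangent is a pair as well. The names `Rev`, `vjpPair` and `pairObjective` are hypothetical and not horde-ad's API. The result is the transposed Jacobian applied to the cotangent, so a one-hot cotangent selects one row of the Jacobian.

```haskell
-- Toy reverse-mode value (hypothetical): primal value plus a
-- backpropagator from a cotangent to (input index, contribution) pairs.
data Rev = Rev Double (Double -> [(Int, Double)])

instance Num Rev where
  Rev a da + Rev b db = Rev (a + b) (\c -> da c ++ db c)
  Rev a da - Rev b db = Rev (a - b) (\c -> da c ++ db (negate c))
  Rev a da * Rev b db = Rev (a * b) (\c -> da (c * b) ++ db (c * a))
  negate (Rev a da) = Rev (negate a) (da . negate)
  abs (Rev a da) = Rev (abs a) (\c -> da (c * signum a))
  signum (Rev a _) = Rev (signum a) (const [])
  fromInteger n = Rev (fromInteger n) (const [])

-- Sum the contributions for each of the n inputs.
accumulate :: Int -> [(Int, Double)] -> [Double]
accumulate n contribs =
  [sum [v | (j, v) <- contribs, j == i] | i <- [0 .. n - 1]]

-- Reverse derivative of a pair-valued function: the cotangent has one
-- component per output, and each output backpropagates its own component.
vjpPair :: ([Rev] -> (Rev, Rev)) -> [Double] -> (Double, Double) -> [Double]
vjpPair f xs (dt1, dt2) = accumulate (length xs) (back1 dt1 ++ back2 dt2)
  where
    inputs = [Rev x (\c -> [(i, c)]) | (i, x) <- zip [0 ..] xs]
    (Rev _ back1, Rev _ back2) = f inputs

-- Example: f(x, y) = (x * y, x + y).
pairObjective :: Num a => [a] -> (a, a)
pairObjective [x, y] = (x * y, x + y)
pairObjective _ = (0, 0)
```

For f(x, y) = (xy, x + y) at (2, 5), the Jacobian rows are (5, 2) and (1, 1); accordingly `vjpPair pairObjective [2, 5] (1, 0)` gives `[5.0, 2.0]` and the cotangent `(3, 4)` gives `[19.0, 10.0]`.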

gradArtifact

Arguments

:: forall src r tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ('TKScalar r)) 
=> (src -> tgt)

the objective function

-> Value src 
-> AstArtifactRev (X src) ('TKScalar r)

the artifact containing the symbolic code of the derivative

Compute the reverse derivative not for a specific input, but as a symbolic function from inputs to the gradient value. The function is represented as an "artifact", which is the gradient AST term together with the variable corresponding to the input.
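The idea of an artifact, a derivative term built once together with the variables it is expressed in, can be sketched with a toy expression type. `Expr`, `diff`, `GradArtifact`, `gradArtifactToy` and `interpretGrad` are hypothetical and much simpler than horde-ad's `AstTensor` and `AstArtifactRev`. In particular, this toy differentiates with the textbook symbolic rules rather than by reverse-mode AD, and it takes the number of inputs explicitly.

```haskell
-- Toy expression type (hypothetical; not horde-ad's AstTensor).
data Expr
  = Var Int     -- the i-th input
  | Lit Double
  | Add Expr Expr
  | Mul Expr Expr

instance Num Expr where
  (+) = Add
  (*) = Mul
  negate = Mul (Lit (-1))
  fromInteger = Lit . fromInteger
  abs = error "abs: not supported in this toy"
  signum = error "signum: not supported in this toy"

-- Symbolic partial derivative with respect to input i.
diff :: Int -> Expr -> Expr
diff i (Var j) = Lit (if i == j then 1 else 0)
diff _ (Lit _) = Lit 0
diff i (Add a b) = Add (diff i a) (diff i b)
diff i (Mul a b) = Add (Mul (diff i a) b) (Mul a (diff i b))

-- The artifact: one gradient term per input, expressed in the input
-- variables Var 0 .. Var (n - 1) and built without concrete values.
newtype GradArtifact = GradArtifact [Expr]

gradArtifactToy :: Int -> ([Expr] -> Expr) -> GradArtifact
gradArtifactToy n f = GradArtifact [diff i body | i <- [0 .. n - 1]]
  where
    body = f (map Var [0 .. n - 1])

-- Evaluate a term given concrete values of the inputs.
eval :: [Double] -> Expr -> Double
eval xs (Var i) = xs !! i
eval _ (Lit c) = c
eval xs (Add a b) = eval xs a + eval xs b
eval xs (Mul a b) = eval xs a * eval xs b

-- Interpreting the artifact requires no further differentiation.
interpretGrad :: GradArtifact -> [Double] -> [Double]
interpretGrad (GradArtifact terms) xs = map (eval xs) terms

-- Example objective: f(x, y) = x^2 * y + 3 * y.
objective :: Num a => [a] -> a
objective [x, y] = x * x * y + 3 * y
objective _ = 0
```

The artifact for `objective` is built once; interpreting it at `[2, 5]` and at `[1, 1]` yields `[20.0, 7.0]` and `[2.0, 4.0]` without differentiating again.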

vjpArtifact

Arguments

:: forall src (ztgt :: TK) tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) 
=> (src -> tgt)

the objective function

-> Value src 
-> AstArtifactRev (X src) ztgt

the artifact containing the symbolic code of the derivative

Compute the reverse derivative not for a specific input, but as a symbolic function from inputs and incoming cotangents to the gradient value. The function is represented as an "artifact", which is the gradient AST term together with variables corresponding to the input and cotangent.
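In the same toy style, the variables of a reverse-derivative artifact can include the incoming cotangent. In the hypothetical sketch below, `Cot j` stands for the j-th cotangent component, and each gradient term is the sum over outputs of `Cot j` times the partial derivative of output j, so the artifact can later be interpreted with any cotangent. None of these names (`Expr`, `VjpArtifact`, `vjpArtifactToy`, `interpretVjp`) belong to horde-ad.

```haskell
-- Toy expression type (hypothetical; not horde-ad's AstTensor).
data Expr
  = Var Int     -- the i-th input
  | Cot Int     -- the j-th component of the incoming cotangent
  | Lit Double
  | Add Expr Expr
  | Mul Expr Expr

instance Num Expr where
  (+) = Add
  (*) = Mul
  negate = Mul (Lit (-1))
  fromInteger = Lit . fromInteger
  abs = error "abs: not supported in this toy"
  signum = error "signum: not supported in this toy"

-- Symbolic partial derivative with respect to input i.
diff :: Int -> Expr -> Expr
diff i (Var j) = Lit (if i == j then 1 else 0)
diff _ (Cot _) = Lit 0
diff _ (Lit _) = Lit 0
diff i (Add a b) = Add (diff i a) (diff i b)
diff i (Mul a b) = Add (Mul (diff i a) b) (Mul a (diff i b))

-- The artifact: for each input i, sum over outputs j of Cot j * d out_j / d x_i.
newtype VjpArtifact = VjpArtifact [Expr]

vjpArtifactToy :: Int -> ([Expr] -> [Expr]) -> VjpArtifact
vjpArtifactToy n f =
  VjpArtifact [sum [Cot j * diff i o | (j, o) <- zip [0 ..] outs] | i <- [0 .. n - 1]]
  where
    outs = f (map Var [0 .. n - 1])

-- Evaluate a term given concrete inputs and a concrete cotangent.
eval :: [Double] -> [Double] -> Expr -> Double
eval xs _ (Var i) = xs !! i
eval _ dts (Cot j) = dts !! j
eval _ _ (Lit c) = c
eval xs dts (Add a b) = eval xs dts a + eval xs dts b
eval xs dts (Mul a b) = eval xs dts a * eval xs dts b

interpretVjp :: VjpArtifact -> [Double] -> [Double] -> [Double]
interpretVjp (VjpArtifact terms) xs dts = map (eval xs dts) terms

-- Example: f(x, y) = (x * y, x + y).
pairObjective :: Num a => [a] -> [a]
pairObjective [x, y] = [x * y, x + y]
pairObjective _ = []
```

Interpreting the artifact for f(x, y) = (xy, x + y) at `[2, 5]` with cotangent `[3, 4]` gives `[19.0, 10.0]`; with `[1, 0]` it gives `[5.0, 2.0]`, with no new differentiation.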

gradInterpretArtifact

Arguments

:: forall (x :: TK) r avals. (X avals ~ ADTensorKind x, AdaptableTarget Concrete avals) 
=> AstArtifactRev x ('TKScalar r)

the artifact containing the symbolic code of the derivative

-> Concrete x 
-> avals 

Interpret the "artifact" as a function from a concrete tensor to a concrete tensor (possibly adapted, e.g., from horde-ad nested pairs to Haskell n-tuples).
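The adaptation from nested pairs to Haskell n-tuples can be illustrated with a toy type class. The nested-pair type `Nested`, the class `Adapt` and the convention that a triple is a right-nested pair are all assumptions of this sketch, not horde-ad's `AdaptableTarget`. Also, where the real signature relates the structures at the type level (`X avals ~ ADTensorKind x`), the toy checks them at run time and returns `Maybe`.

```haskell
-- Hypothetical nested-pair representation of a concrete result.
data Nested = Leaf Double | Pair Nested Nested
  deriving Show

-- Toy analogue of adapting nested pairs to ordinary Haskell values;
-- the structure is checked at run time, hence Maybe.
class Adapt a where
  fromNested :: Nested -> Maybe a

instance Adapt Double where
  fromNested (Leaf d) = Just d
  fromNested _ = Nothing

instance (Adapt a, Adapt b) => Adapt (a, b) where
  fromNested (Pair l r) = (,) <$> fromNested l <*> fromNested r
  fromNested _ = Nothing

-- Convention assumed for this sketch: a triple is a right-nested pair.
instance (Adapt a, Adapt b, Adapt c) => Adapt (a, b, c) where
  fromNested (Pair x (Pair y z)) =
    (,,) <$> fromNested x <*> fromNested y <*> fromNested z
  fromNested _ = Nothing
```

For example, `fromNested (Pair (Leaf 1) (Pair (Leaf 2) (Leaf 3))) :: Maybe (Double, Double, Double)` yields `Just (1.0,2.0,3.0)`.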

vjpInterpretArtifact

Arguments

:: forall (x :: TK) (z :: TK) avals. (X avals ~ ADTensorKind x, AdaptableTarget Concrete avals) 
=> AstArtifactRev x z

the artifact containing the symbolic code of the derivative

-> Concrete x 
-> Concrete (ADTensorKind z) 
-> avals 

Interpret the "artifact" as a function from concrete tensors to a concrete tensor (possibly adapted, e.g., from horde-ad nested pairs to Haskell n-tuples).

Symbolic forward derivative adaptors

jvp

Arguments

:: forall src (ztgt :: TK) tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) 
=> (src -> tgt)

the objective function

-> Value src 
-> Value src 
-> Concrete (ADTensorKind ztgt) 

The forward derivative operation takes the perturbation parameter by convention. It permits an arbitrary (nested tuple+) type for the domain and an arbitrary (nested pair) tensor kind for the codomain of the function to be differentiated. This generality sometimes makes it necessary to supply type hints when applying this operation.
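Why the perturbation is a parameter can be seen from a toy forward-mode sketch: a single pass with dual numbers computes the Jacobian applied to one perturbation vector, so the caller must choose that vector, and a one-hot perturbation extracts one column of the Jacobian. `Dual`, `jvpToy` and `pairObjective` are hypothetical names, not horde-ad's API.

```haskell
-- Toy forward-mode dual number (hypothetical): primal value and tangent.
data Dual = Dual Double Double

instance Num Dual where
  Dual a a' + Dual b b' = Dual (a + b) (a' + b')
  Dual a a' - Dual b b' = Dual (a - b) (a' - b')
  Dual a a' * Dual b b' = Dual (a * b) (a * b' + a' * b)
  negate (Dual a a') = Dual (negate a) (negate a')
  abs (Dual a a') = Dual (abs a) (signum a * a')
  signum (Dual a _) = Dual (signum a) 0
  fromInteger n = Dual (fromInteger n) 0

-- Pair each input with its perturbation component, run f once, and read
-- off the output tangents: this is the Jacobian times the perturbation.
jvpToy :: ([Dual] -> [Dual]) -> [Double] -> [Double] -> [Double]
jvpToy f xs dxs = [t | Dual _ t <- f (zipWith Dual xs dxs)]

-- Example: f(x, y) = (x * y, x + y).
pairObjective :: Num a => [a] -> [a]
pairObjective [x, y] = [x * y, x + y]
pairObjective _ = []
```

For f(x, y) = (xy, x + y) at (2, 5), the perturbation `[1, 0]` gives the first Jacobian column `[5.0, 1.0]`, and `[3, 4]` gives `[23.0, 7.0]`.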

jvpArtifact

Arguments

:: forall src (ztgt :: TK) tgt. (X src ~ X (Value src), KnownSTK (X src), AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, AdaptableTarget Concrete (Value src), tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) 
=> (src -> tgt)

the objective function

-> Value src 
-> AstArtifactFwd (X src) ztgt

the artifact containing the symbolic code of the derivative

Compute the forward derivative not for a specific input, but as a symbolic function from inputs and perturbation to the derivative value. The function is represented as an "artifact", which is the derivative AST term together with variables corresponding to the input and perturbation.
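A forward-derivative artifact can be sketched in a similar toy style: the hypothetical `tangentOf` pushes a symbolic perturbation (`Tan i` for the i-th input) through the term, producing a derivative term in the input and perturbation variables that is then interpreted with concrete values. `Expr`, `tangentOf`, `JvpArtifact`, `jvpArtifactToy` and `interpretJvp` are illustrative names, not horde-ad's `AstArtifactFwd` machinery.

```haskell
-- Toy expression type (hypothetical; not horde-ad's AstTensor).
data Expr
  = Var Int     -- the i-th input
  | Tan Int     -- the i-th component of the perturbation
  | Lit Double
  | Add Expr Expr
  | Mul Expr Expr

instance Num Expr where
  (+) = Add
  (*) = Mul
  negate = Mul (Lit (-1))
  fromInteger = Lit . fromInteger
  abs = error "abs: not supported in this toy"
  signum = error "signum: not supported in this toy"

-- Symbolic directional derivative: propagate the perturbation variables
-- through the term using the sum and product rules.
tangentOf :: Expr -> Expr
tangentOf (Var i) = Tan i
tangentOf (Tan _) = Lit 0
tangentOf (Lit _) = Lit 0
tangentOf (Add a b) = Add (tangentOf a) (tangentOf b)
tangentOf (Mul a b) = Add (Mul (tangentOf a) b) (Mul a (tangentOf b))

-- The artifact: one derivative term per output, built once.
newtype JvpArtifact = JvpArtifact [Expr]

jvpArtifactToy :: Int -> ([Expr] -> [Expr]) -> JvpArtifact
jvpArtifactToy n f = JvpArtifact (map tangentOf (f (map Var [0 .. n - 1])))

-- Evaluate a term given concrete inputs and a concrete perturbation.
eval :: [Double] -> [Double] -> Expr -> Double
eval xs _ (Var i) = xs !! i
eval _ dxs (Tan i) = dxs !! i
eval _ _ (Lit c) = c
eval xs dxs (Add a b) = eval xs dxs a + eval xs dxs b
eval xs dxs (Mul a b) = eval xs dxs a * eval xs dxs b

interpretJvp :: JvpArtifact -> [Double] -> [Double] -> [Double]
interpretJvp (JvpArtifact terms) xs dxs = map (eval xs dxs) terms

-- Example: f(x, y) = (x * y, x + y).
pairObjective :: Num a => [a] -> [a]
pairObjective [x, y] = [x * y, x + y]
pairObjective _ = []
```

Building the artifact for f(x, y) = (xy, x + y) once and interpreting it at `[2, 5]` with perturbations `[3, 4]` and `[1, 0]` gives `[23.0, 7.0]` and `[5.0, 1.0]`.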

jvpInterpretArtifact

Arguments

:: forall (x :: TK) (z :: TK). AstArtifactFwd x z

the artifact containing the symbolic code of the derivative

-> Concrete x 
-> Concrete (ADTensorKind x) 
-> Concrete (ADTensorKind z) 

Interpret the "artifact" as a function from concrete tensors to a concrete tensor.

Non-symbolic reverse derivative adaptors

cgrad

Arguments

:: forall src r tgt. (X src ~ X (DValue src), KnownSTK (X src), AdaptableTarget (ADVal Concrete) src, AdaptableTarget Concrete (DValue src), tgt ~ ADVal Concrete ('TKScalar r)) 
=> (src -> tgt)

the objective function

-> DValue src 
-> DValue src 

This simplified version of the concrete (non-symbolic) reverse derivative operation sets the incoming cotangent dt to be 1 and assumes the codomain of the function to be differentiated is a scalar.

cvjp

Arguments

:: forall src (ztgt :: TK) tgt. (X src ~ X (DValue src), KnownSTK (X src), AdaptableTarget (ADVal Concrete) src, AdaptableTarget Concrete (DValue src), tgt ~ ADVal Concrete ztgt) 
=> (src -> tgt)

the objective function

-> DValue src 
-> Concrete (ADTensorKind ztgt) 
-> DValue src 

This more general version of the concrete (non-symbolic) reverse derivative operation additionally takes the sensitivity parameter (the incoming cotangent).

Non-symbolic forward derivative adaptors

cjvp

Arguments

:: forall src (ztgt :: TK) tgt. (X src ~ X (DValue src), KnownSTK (X src), AdaptableTarget (ADVal Concrete) src, AdaptableTarget Concrete (DValue src), tgt ~ ADVal Concrete ztgt) 
=> (src -> tgt)

the objective function

-> DValue src 
-> DValue src 
-> Concrete (ADTensorKind ztgt) 

Concrete (non-symbolic) forward derivative operation. It always takes the perturbation parameter, by convention.

Internal machinery for symbolic adaptors

revArtifactAdapt

Arguments

:: forall src (ztgt :: TK) tgt. (AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) 
=> IncomingCotangentHandling 
-> (src -> tgt)

the objective function

-> FullShapeTK (X src) 
-> AstArtifactRev (X src) ztgt

the artifact containing the symbolic code of the derivative

revArtifactDelta

Arguments

:: forall src (ztgt :: TK) tgt. (AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) 
=> IncomingCotangentHandling 
-> (src -> tgt)

the objective function

-> FullShapeTK (X src) 
-> (AstArtifactRev (X src) ztgt, Delta (AstRaw 'PrimalSpan) ztgt)

the artifact containing the symbolic code of the derivative

revProduceArtifactWithoutInterpretation

Arguments

:: forall (x :: TK) (z :: TK). IncomingCotangentHandling 
-> (ADVal (AstRaw 'PrimalSpan) x -> ADVal (AstRaw 'PrimalSpan) z) 
-> FullShapeTK x 
-> (AstArtifactRev x z, Delta (AstRaw 'PrimalSpan) z)

the artifact containing the symbolic code of the derivative

revInterpretArtifact

Arguments

:: forall (x :: TK) (z :: TK). AstArtifactRev x z

the artifact containing the symbolic code of the derivative

-> Concrete x 
-> Maybe (Concrete (ADTensorKind z)) 
-> (Concrete (ADTensorKind x), Concrete z) 

fwdArtifactAdapt

Arguments

:: forall src (ztgt :: TK) tgt. (AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) 
=> (src -> tgt)

the objective function

-> FullShapeTK (X src) 
-> AstArtifactFwd (X src) ztgt

the artifact containing the symbolic code of the derivative

fwdArtifactDelta

Arguments

:: forall src (ztgt :: TK) tgt. (AdaptableTarget (AstTensor 'AstMethodLet 'FullSpan) src, tgt ~ AstTensor 'AstMethodLet 'FullSpan ztgt) 
=> (src -> tgt)

the objective function

-> FullShapeTK (X src) 
-> (AstArtifactFwd (X src) ztgt, Delta (AstRaw 'PrimalSpan) ztgt)

the artifact containing the symbolic code of the derivative

fwdInterpretArtifact

Arguments

:: forall (x :: TK) (z :: TK). AstArtifactFwd x z

the artifact containing the symbolic code of the derivative

-> Concrete x 
-> Concrete (ADTensorKind x) 
-> (Concrete (ADTensorKind z), Concrete z) 

Internal machinery for non-symbolic adaptors

cfwdBoth

Arguments

:: forall src (ztgt :: TK) tgt. (X src ~ X (DValue src), KnownSTK (X src), AdaptableTarget (ADVal Concrete) src, AdaptableTarget Concrete (DValue src), tgt ~ ADVal Concrete ztgt) 
=> (src -> tgt)

the objective function

-> DValue src 
-> DValue src 
-> (Concrete (ADTensorKind ztgt), Concrete ztgt)