{-# OPTIONS_GHC -Wno-orphans #-}
-- | The implementation of reverse derivative and forward derivative
-- calculation for an objective function on values of complicated types,
-- e.g., nested tuples of tensors.
--
-- The objective function can be defined as a sufficiently polymorphic
-- Haskell function that uses numeric classes as well as the multi-dimensional
-- tensor operations listed in "HordeAd.OpsTensor". To obtain symbolic
-- derivatives (derivative code that can be executed many times without
-- performing AD again), the user needs an objective function polymorphic
-- enough so that it can be instantiated to the 'HordeAd.Core.Ast.AstTensor'
-- type (nested in tuples, etc., for some extra flexibility).
-- For non-symbolic derivatives, the ability to instantiate to the
-- `HordeAd.Core.CarriersADVal.ADVal` type of dual numbers is enough.
-- See the classes these types are instances of to gauge the breadth
-- of the offered respective APIs.
module HordeAd.ADEngine
  ( -- * Symbolic reverse derivative adaptors
    grad, vjp
  , gradArtifact, vjpArtifact
  , gradInterpretArtifact, vjpInterpretArtifact
    -- * Symbolic forward derivative adaptors
  , jvp, jvpArtifact, jvpInterpretArtifact
    -- * Non-symbolic reverse derivative adaptors
  , cgrad, cvjp
    -- * Non-symbolic forward derivative adaptors
  , cjvp
    -- * Internal machinery for symbolic adaptors
  , IncomingCotangentHandling(..)
  , revArtifactAdapt, revArtifactDelta
  , revProduceArtifactWithoutInterpretation, revInterpretArtifact
  , fwdArtifactAdapt, fwdArtifactDelta, fwdInterpretArtifact
    -- * Internal machinery for non-symbolic adaptors
  , cfwdBoth
  ) where

import Prelude

import HordeAd.AstEngine
import HordeAd.Core.Adaptor
import HordeAd.Core.Ast
import HordeAd.Core.AstEnv
import HordeAd.Core.AstInterpret
import HordeAd.Core.CarriersADVal
import HordeAd.Core.CarriersAst
import HordeAd.Core.CarriersConcrete
import HordeAd.Core.Delta
import HordeAd.Core.DeltaEval
import HordeAd.Core.Ops
import HordeAd.Core.OpsADVal
import HordeAd.Core.OpsAst
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
import HordeAd.Core.Unwind

-- * Symbolic reverse derivative adaptors

-- | This simplified version of the symbolic reverse derivative operation
-- sets the incoming cotangent @dt@ to be 1 and assumes the codomain
-- of the function to be differentiated is a scalar.
--
-- We don't enforce (e.g., by quantification) that the objective function
-- is closed, because we evaluate the result of the differentiation
-- down to concrete arrays and so there's no risk of "perturbation confusion"
-- between different levels of differentiation if it's done multiple times.
-- For simplicity of the type signature, the resulting value is converted from
-- the type of concrete cotangents to the type of concrete input parameters.
grad
  :: forall src r tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan (TKScalar r) )
  => (src -> tgt)  -- ^ the objective function
  -> Value src     -- ^ the input at which the function is differentiated
  -> Value src  -- morally Value (ADTensorKind src)
{-# INLINE grad #-}
-- No explicit cotangent: 'revMaybe' then interprets the artifact
-- with a cotangent filled with 1.
grad f vals = revMaybe f vals Nothing

-- | This version of the symbolic reverse derivative operation
-- explicitly takes the sensitivity parameter (the incoming cotangent).
-- It also permits an arbitrary (nested tuple+) type of the domain
-- and arbitrary (nested pair) tensor kind of the codomain
-- of the function to be differentiated. The downside of the generality
-- is that if the function doesn't have an explicit type signature,
-- the type to which this operation is instantiated often has to be spelled
-- in full via explicit type applications to aid type reconstruction.
-- For simplicity of the type signature, the resulting value is converted from
-- the type of concrete cotangents to the type of concrete input parameters.
vjp
  :: forall src ztgt tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> Value src     -- ^ the input at which the function is differentiated
  -> Concrete (ADTensorKind ztgt)
       -- ^ the incoming cotangent; its shape must match the codomain
  -> Value src  -- morally Value (ADTensorKind src)
{-# INLINE vjp #-}
vjp f vals dt = revMaybe f vals (Just dt)

-- | Compute the reverse derivative not for a specific input, but as
-- a symbolic function from inputs to the gradient value.
-- The function is represented as an "artifact", which is the gradient
-- AST term together with the variable corresponding to the input.
-- The given input value is used only for its shape.
gradArtifact
  :: forall src r tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan (TKScalar r) )
  => (src -> tgt)  -- ^ the objective function
  -> Value src     -- ^ an input of the intended shape
  -> AstArtifactRev (X src) (TKScalar r)
       -- ^ the artifact containing the symbolic code of the derivative
{-# INLINE gradArtifact #-}
gradArtifact f vals0 =
  -- Only the full shape of the concrete input is extracted; it is what
  -- 'revArtifactAdapt' takes to build the symbolic derivative.
  let xftk = tftkG (knownSTK @(X src)) $ unConcrete $ toTarget vals0
  in revArtifactAdapt IgnoreIncomingCotangent f xftk

-- | Compute the reverse derivative not for a specific input, but as
-- a symbolic function from inputs and incoming cotangents to the gradient
-- value. The function is represented as an "artifact", which is the gradient
-- AST term together with variables corresponding to the input and cotangent.
-- The given input value is used only for its shape.
vjpArtifact
  :: forall src ztgt tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> Value src     -- ^ an input of the intended shape
  -> AstArtifactRev (X src) ztgt
       -- ^ the artifact containing the symbolic code of the derivative
{-# INLINE vjpArtifact #-}
vjpArtifact f vals0 =
  -- Only the full shape of the concrete input is extracted; it is what
  -- 'revArtifactAdapt' takes to build the symbolic derivative.
  let xftk = tftkG (knownSTK @(X src)) $ unConcrete $ toTarget vals0
  in revArtifactAdapt UseIncomingCotangent f xftk

-- | Interpret the "artifact" as a function from a concrete tensor
-- to a concrete tensor (possibly adapted, e.g., from horde-ad nested pairs
-- to Haskell n-tuples). The incoming cotangent is set to 1.
--
-- Fails with 'error' if the shape of the argument differs from the shape
-- of the artifact's domain variable.
gradInterpretArtifact
  :: forall x r avals.
     (X avals ~ ADTensorKind x, AdaptableTarget Concrete avals)
  => AstArtifactRev x (TKScalar r)
       -- ^ the artifact containing the symbolic code of the derivative
  -> Concrete x  -- ^ the input at which the derivative is evaluated
  -> avals
{-# INLINE gradInterpretArtifact #-}
gradInterpretArtifact AstArtifactRev{..} parameters =
  let xftk = varNameToFTK artVarDomainRev
      azftk = varNameToFTK artVarDtRev
                -- STKScalar @(ADTensorScalar r) or STKScalar @Z1
      -- The cotangent variable is bound to a value filled with 1
      -- of the shape recorded in the cotangent variable.
      oneAtF = treplTarget 1 azftk
      env = extendEnv artVarDtRev oneAtF
            $ extendEnv artVarDomainRev parameters emptyEnv
  in if tftkG (ftkToSTK xftk) (unConcrete parameters) == xftk
     then fromTarget $ interpretAstPrimal env artDerivativeRev
     else error "gradInterpretArtifact: reverse derivative parameters must have the same shape as the domain of the objective function"

-- | Interpret the "artifact" as a function from concrete tensors
-- to a concrete tensor (possibly adapted, e.g., from horde-ad nested pairs
-- to Haskell n-tuples).
--
-- Fails with 'error' if the shape of the input differs from the shape
-- of the artifact's domain variable, or if the shape of the incoming
-- cotangent differs from the shape of the artifact's cotangent variable.
vjpInterpretArtifact
  :: forall x z avals.
     (X avals ~ ADTensorKind x, AdaptableTarget Concrete avals)
  => AstArtifactRev x z
       -- ^ the artifact containing the symbolic code of the derivative
  -> Concrete x                -- ^ the input at which it is evaluated
  -> Concrete (ADTensorKind z)  -- ^ the incoming cotangent
  -> avals
{-# INLINE vjpInterpretArtifact #-}
vjpInterpretArtifact AstArtifactRev{..} parameters dt =
  let xftk = varNameToFTK artVarDomainRev
      azftk = varNameToFTK artVarDtRev
      env = extendEnv artVarDtRev dt
            $ extendEnv artVarDomainRev parameters emptyEnv
  in if tftkG (ftkToSTK xftk) (unConcrete parameters) == xftk
     then if tftkG (ftkToSTK azftk) (unConcrete dt) == azftk
          then fromTarget $ interpretAstPrimal env artDerivativeRev
          else error "vjpInterpretArtifact: reverse derivative incoming cotangent must have the same shape as the codomain of the objective function"
     else error "vjpInterpretArtifact: reverse derivative parameters must have the same shape as the domain of the objective function"


-- * Symbolic reverse derivative adaptors' internal machinery

-- | The shared implementation of 'grad' and 'vjp'.
--
-- Builds a reverse-derivative artifact for the shape of the given input
-- (via 'revArtifactAdapt'), simplifies it with 'simplifyArtifactGradient',
-- and interprets it at that input with 'revInterpretArtifact'. When no
-- cotangent is given, 'IgnoreIncomingCotangent' is requested and the
-- interpreter uses a cotangent filled with 1; otherwise
-- 'UseIncomingCotangent' is requested and the cotangent is passed through.
-- The resulting gradient is converted out of the 'ADTensorKind'
-- representation with 'fromADTensorKindShared' and adapted via 'fromTarget'.
revMaybe
  :: forall src ztgt tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> Value src     -- ^ the input at which the function is differentiated
  -> Maybe (Concrete (ADTensorKind ztgt))  -- ^ optional incoming cotangent
  -> Value src  -- morally Value (ADTensorKind src)
{-# INLINE revMaybe #-}
revMaybe f vals0 mdt =
  let valsTarget = toTarget vals0
      xftk = tftkG (knownSTK @(X src)) $ unConcrete valsTarget
      cotangentHandling =
        maybe IgnoreIncomingCotangent (const UseIncomingCotangent) mdt
      artifactRaw = revArtifactAdapt cotangentHandling f xftk
      artifact = simplifyArtifactGradient artifactRaw
  -- Only the gradient component of the interpretation result is used;
  -- the primal value is discarded by 'fst'.
  in fromTarget $ fromADTensorKindShared (ftkToSTK xftk)
     $ fst $ revInterpretArtifact artifact valsTarget mdt

-- | Produce a reverse-derivative artifact for an objective function
-- on an adaptable (e.g., nested tuple) domain, given the full shape
-- of that domain.
--
-- The objective function is wrapped into a function on the AST
-- representation of the whole domain. The wrapper shares the argument
-- with 'ttlet', rebuilds the user-level value with 'fromTarget', applies
-- the objective function and simplifies the result with 'simplifyInline'.
-- The wrapper is then handed to 'revProduceArtifact' with an empty
-- environment.
revArtifactAdapt
  :: forall src ztgt tgt.
     ( AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => IncomingCotangentHandling
  -> (src -> tgt)  -- ^ the objective function
  -> FullShapeTK (X src)  -- ^ the full shape of the domain
  -> AstArtifactRev (X src) ztgt
       -- ^ the artifact containing the symbolic code of the derivative
{-# INLINE revArtifactAdapt #-}
revArtifactAdapt cotangentHandling f xftk =
  let g :: AstTensor AstMethodLet FullSpan (X src) -> tgt
      g !arg = simplifyInline $ ttlet arg $ f . fromTarget
                                  -- fromTarget requires duplicable
  in revProduceArtifact cotangentHandling g emptyEnv xftk

-- | Evaluate a reverse-derivative artifact at a concrete input, returning
-- the pair of the gradient and the primal value of the objective function.
--
-- The domain variable is bound to the given input. The cotangent variable
-- is bound to the given cotangent, or, when none is given, to a value
-- filled with 1 of the shape recorded in the cotangent variable.
-- Fails with 'error' if a given cotangent has a shape different from
-- that of the cotangent variable.
--
-- NOTE(review): unlike 'gradInterpretArtifact' and 'vjpInterpretArtifact',
-- the shape of the input is not checked here.
revInterpretArtifact
  :: forall x z.
     AstArtifactRev x z
       -- ^ the artifact containing the symbolic code of the derivative
  -> Concrete x  -- ^ the input at which the artifact is evaluated
  -> Maybe (Concrete (ADTensorKind z))  -- ^ optional incoming cotangent
  -> (Concrete (ADTensorKind x), Concrete z)
       -- ^ the gradient and the primal value
{-# INLINE revInterpretArtifact #-}
revInterpretArtifact AstArtifactRev{..} parameters mdt =
  let azftk = varNameToFTK artVarDtRev
      env = extendEnv artVarDomainRev parameters emptyEnv
      envDt = case mdt of
        Nothing ->
          let oneAtF = treplTarget 1 azftk
          in extendEnv artVarDtRev oneAtF env
        Just dt ->
          if tftkG (ftkToSTK azftk) (unConcrete dt) == azftk
          then extendEnv artVarDtRev dt env
          else error "revInterpretArtifact: reverse derivative incoming cotangent must have the same shape as the codomain of the objective function"
      -- The derivative needs the cotangent binding; the primal does not.
      gradient = interpretAstPrimal envDt artDerivativeRev
      primal = interpretAstPrimal env artPrimalRev
  in (gradient, primal)


-- * Symbolic reverse derivative adaptors' testing-only internal machinery

-- | Testing-only variant of reverse-artifact production. It builds the
-- forward pass by interpreting the objective function's term in an empty
-- environment ('forwardPassByInterpretation') and returns, besides the
-- artifact, a 'Delta' term so that tests can inspect it.
-- No 'simplifyInline' is applied to the objective term here.
revArtifactDelta
  :: forall src ztgt tgt.
     ( AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => IncomingCotangentHandling
  -> (src -> tgt)  -- ^ the objective function
  -> FullShapeTK (X src)
  -> (AstArtifactRev (X src) ztgt, Delta (AstRaw PrimalSpan) ztgt)
       -- ^ the artifact containing the symbolic code of the derivative
{-# INLINE revArtifactDelta #-}
revArtifactDelta cotangentHandling f xftk =
  let -- 'fromTarget' requires a duplicable argument, hence the 'ttlet'.
      g :: AstTensor AstMethodLet FullSpan (X src) -> tgt
      g !arg = ttlet arg $ f . fromTarget
  in revArtifactFromForwardPass cotangentHandling
                                (forwardPassByInterpretation g emptyEnv) xftk

-- | Testing-only production of a reverse-derivative artifact from a function
-- that works directly on dual numbers over 'AstRaw'. The function is applied
-- to the dual-number input ('forwardPassByApplication') instead of having
-- an AST interpreted. The artifact is returned together with a 'Delta' term.
revProduceArtifactWithoutInterpretation
  :: forall x z.
     IncomingCotangentHandling
  -> (ADVal (AstRaw PrimalSpan) x -> ADVal (AstRaw PrimalSpan) z)
  -> FullShapeTK x
  -> (AstArtifactRev x z, Delta (AstRaw PrimalSpan) z)
       -- ^ the artifact containing the symbolic code of the derivative
{-# INLINE revProduceArtifactWithoutInterpretation #-}
revProduceArtifactWithoutInterpretation cotangentHandling f xftk =
  -- No simplification performed to let individual tests decide.
  revArtifactFromForwardPass cotangentHandling
                             (forwardPassByApplication f)
                             xftk

-- | A forward pass that simply applies the given dual-number function.
--
-- The input dual number is built as follows:
--
-- * its primal part is the given primal AST term, wrapped as 'AstRaw';
-- * its dual part is 'generateDeltaInputs' for the full shape recorded
--   in the given variable name.
--
-- The last argument (a full-span AST term) is ignored. It is accepted only
-- so that this function has the forward-pass type expected by
-- 'revArtifactFromForwardPass', the same type that
-- 'forwardPassByInterpretation' produces.
forwardPassByApplication
  :: forall x z.
     (ADVal (AstRaw PrimalSpan) x -> ADVal (AstRaw PrimalSpan) z)
  -> AstTensor AstMethodShare PrimalSpan x
  -> AstVarName FullSpan x
  -> AstTensor AstMethodLet FullSpan x
  -> ADVal (AstRaw PrimalSpan) z
{-# INLINE forwardPassByApplication #-}
forwardPassByApplication g astVarPrimal var _astVar =
  let inputDeltas :: Delta (AstRaw PrimalSpan) x
      inputDeltas = generateDeltaInputs (varNameToFTK var)
      dualInput :: ADVal (AstRaw PrimalSpan) x
      dualInput = dDnotShared (AstRaw astVarPrimal) inputDeltas
  in g dualInput


-- * Symbolic forward derivative adaptors

-- | The symbolic forward derivative operation (Jacobian-vector product).
-- By convention it takes the perturbation parameter explicitly.
-- It permits an arbitrary (nested tuple+) type of the domain and an
-- arbitrary (nested pair) tensor kind of the codomain of the function
-- to be differentiated. The generality sometimes makes it necessary
-- to supply type hints when applying this operation.
--
-- The derivative artifact is produced with 'fwdArtifactAdapt', its
-- derivative code is simplified with 'simplifyArtifactDerivative', and
-- it is then interpreted on the concrete input and perturbation.
-- Only the derivative is returned; the primal value is discarded.
jvp
  :: forall src ztgt tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> Value src     -- ^ the input at which the derivative is taken
  -> Value src     -- ^ the perturbation; morally (ADTensorKind src)
  -> Concrete (ADTensorKind ztgt)
{-# INLINE jvp #-}
jvp f vals0 ds =
  let valsTarget = toTarget vals0
      xftk = tftkG (knownSTK @(X src)) $ unConcrete valsTarget
      artifactRaw = fwdArtifactAdapt f xftk
      artifact = simplifyArtifactDerivative artifactRaw
  in fst $ fwdInterpretArtifact artifact valsTarget
         $ toADTensorKindShared xftk (toTarget ds)
       -- the shapes of vals0 and ds are checked against the artifact's
       -- domain in fwdInterpretArtifact

-- | Compute the forward derivative not for a specific input, but as a
-- symbolic function from inputs and perturbation to the derivative value.
-- The function is represented as an "artifact", which is the derivative
-- AST term together with variables corresponding to the input and
-- perturbation.
--
-- The given value is consulted only for its shape, which fixes the shape
-- of the artifact's domain. Unlike 'jvp', no extra
-- 'simplifyArtifactDerivative' pass is applied here.
jvpArtifact
  :: forall src ztgt tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> Value src     -- ^ a sample input; only its shape is used
  -> AstArtifactFwd (X src) ztgt
       -- ^ the artifact containing the symbolic code of the derivative
{-# INLINE jvpArtifact #-}
jvpArtifact f vals0 =
  let xftk = tftkG (knownSTK @(X src)) $ unConcrete $ toTarget vals0
  in fwdArtifactAdapt f xftk

-- | Interpret the "artifact" as a function from concrete tensors
-- to a concrete tensor. Given the input and the perturbation, it returns
-- the forward derivative. The primal value that 'fwdInterpretArtifact'
-- computes alongside is dropped.
jvpInterpretArtifact
  :: forall x z.
     AstArtifactFwd x z
       -- ^ the artifact containing the symbolic code of the derivative
  -> Concrete x                  -- ^ the input
  -> Concrete (ADTensorKind x)   -- ^ the perturbation
  -> Concrete (ADTensorKind z)
{-# INLINE jvpInterpretArtifact #-}
jvpInterpretArtifact art parameters ds =
  fst $ fwdInterpretArtifact art parameters ds
    -- the shapes of parameters and ds are checked against the artifact's
    -- domain in fwdInterpretArtifact


-- * Symbolic forward derivative adaptors' internal machinery

-- | Produce a forward-derivative artifact for an objective function on
-- AST terms.
--
-- The function is wrapped as follows:
--
-- * its argument is let-bound with 'ttlet' and converted with 'fromTarget';
-- * the resulting term is passed through 'simplifyInline'.
--
-- The wrapped function is then handed to 'fwdProduceArtifact' together
-- with an empty environment.
fwdArtifactAdapt
  :: forall src ztgt tgt.
     ( AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> FullShapeTK (X src)
  -> AstArtifactFwd (X src) ztgt
       -- ^ the artifact containing the symbolic code of the derivative
{-# INLINE fwdArtifactAdapt #-}
fwdArtifactAdapt f xftk =
  let g :: AstTensor AstMethodLet FullSpan (X src) -> tgt
      g !arg = simplifyInline $ ttlet arg $ f . fromTarget
                                  -- fromTarget requires duplicable
  in fwdProduceArtifact g emptyEnv xftk

-- | Run a forward-derivative artifact on concrete values.
--
-- The evaluation proceeds as follows:
--
-- * the input is bound to the artifact's domain variable;
-- * the perturbation is bound to its perturbation variable;
-- * the derivative term is interpreted with both variables bound;
-- * the primal term is interpreted with only the input bound.
--
-- The result is the pair of the derivative and the primal value.
--
-- Fails with 'error' in two cases: the shape of the input differs from
-- the shape of the domain variable, or the shape of the perturbation
-- differs from the 'ADTensorKind' counterpart ('adFTK') of that shape.
-- The input is checked first.
fwdInterpretArtifact
  :: forall x z.
     AstArtifactFwd x z
       -- ^ the artifact containing the symbolic code of the derivative
  -> Concrete x                  -- ^ the input
  -> Concrete (ADTensorKind x)   -- ^ the perturbation
  -> (Concrete (ADTensorKind z), Concrete z)
       -- ^ the derivative and the primal value
{-# INLINE fwdInterpretArtifact #-}
fwdInterpretArtifact AstArtifactFwd{..} parameters ds
  | tftkG xstk (unConcrete parameters) /= xftk =
      error "fwdInterpretArtifact: forward derivative input must have the same shape as the domain of the objective function"
  | tftkG (adSTK xstk) (unConcrete ds) /= adFTK xftk =
      error "fwdInterpretArtifact: forward derivative perturbation must have the same shape as the domain of the objective function"
  | otherwise =
      let env = extendEnv artVarDomainFwd parameters emptyEnv
          envD = extendEnv artVarDsFwd ds env
      in ( interpretAstPrimal envD artDerivativeFwd
         , interpretAstPrimal env artPrimalFwd )
  where
    -- The expected domain shape is the one recorded in the domain variable.
    xftk = varNameToFTK artVarDomainFwd
    xstk = ftkToSTK xftk


-- * Symbolic forward derivative adaptors' testing-only internal machinery

-- | Testing-only variant of forward-artifact production. It builds the
-- forward pass by interpreting the objective function's term in an empty
-- environment ('forwardPassByInterpretation') and returns, besides the
-- artifact, a 'Delta' term so that tests can inspect it.
-- No 'simplifyInline' is applied to the objective term here.
fwdArtifactDelta
  :: forall src ztgt tgt.
     ( AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> FullShapeTK (X src)
  -> (AstArtifactFwd (X src) ztgt, Delta (AstRaw PrimalSpan) ztgt)
       -- ^ the artifact containing the symbolic code of the derivative
{-# INLINE fwdArtifactDelta #-}
fwdArtifactDelta f xftk =
  let -- 'fromTarget' requires a duplicable argument, hence the 'ttlet'.
      g :: AstTensor AstMethodLet FullSpan (X src) -> tgt
      g !arg = ttlet arg $ f . fromTarget
  in fwdArtifactFromForwardPass (forwardPassByInterpretation g emptyEnv) xftk


-- * Non-symbolic reverse derivative adaptors

-- We are inlining these functions because they take function arguments
-- and are not too large. However, because they are called in many places,
-- we break the inline chain not far from the top, to avoid executable
-- size blowup.
--
-- | This simplified version of the concrete (non-symbolic)
-- reverse derivative operation sets the incoming cotangent @dt@ to be 1
-- and assumes the codomain of the function to be differentiated is a scalar.
-- It is 'crevMaybe' with no explicit cotangent ('Nothing').
cgrad
  :: forall src r tgt.
     ( X src ~ X (DValue src), KnownSTK (X src)
     , AdaptableTarget (ADVal Concrete) src
     , AdaptableTarget Concrete (DValue src)
     , tgt ~ ADVal Concrete (TKScalar r) )
  => (src -> tgt)  -- ^ the objective function
  -> DValue src    -- ^ the input at which the gradient is taken
  -> DValue src  -- morally DValue (ADTensorKind src)
{-# INLINE cgrad #-}
cgrad f vals = crevMaybe f vals Nothing

-- | This more general version of the concrete (non-symbolic)
-- reverse derivative operation additionally takes the sensitivity parameter
-- (the incoming cotangent). It is 'crevMaybe' with the cotangent
-- supplied explicitly ('Just').
cvjp
  :: forall src ztgt tgt.
     ( X src ~ X (DValue src), KnownSTK (X src)
     , AdaptableTarget (ADVal Concrete) src
     , AdaptableTarget Concrete (DValue src)
     , tgt ~ ADVal Concrete ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> DValue src    -- ^ the input at which the derivative is taken
  -> Concrete (ADTensorKind ztgt)  -- ^ the incoming cotangent
  -> DValue src  -- morally DValue (ADTensorKind src)
{-# INLINE cvjp #-}
cvjp f vals dt = crevMaybe f vals (Just dt)


-- * Non-symbolic reverse derivative adaptors' internal machinery

-- | Concrete (non-symbolic) reverse derivative of an objective function,
-- with an optional incoming cotangent.
--
-- The input values are flattened into a single 'Concrete' tensor with
-- 'toTarget'. The objective function is restated over a single
-- 'ADVal' 'Concrete' argument, with its structured argument rebuilt by
-- 'fromTarget'. 'crevOnParams' then computes the gradient.
--
-- The gradient is converted back from its 'ADTensorKind' form with
-- 'fromADTensorKindShared' and re-packed into the structure of the input
-- with 'fromTarget'. The second component returned by 'crevOnParams'
-- (a @Concrete ztgt@) is discarded.
--
-- NOTE(review): the incoming cotangent is passed to 'crevOnParams'
-- unchanged. Which cotangent is used for 'Nothing' is decided there, not
-- here.
crevMaybe
  :: forall src ztgt tgt.
     ( X src ~ X (DValue src), KnownSTK (X src)
     , AdaptableTarget (ADVal Concrete) src
     , AdaptableTarget Concrete (DValue src)
     , tgt ~ ADVal Concrete ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> DValue src  -- ^ the values at which the derivative is taken
  -> Maybe (Concrete (ADTensorKind ztgt))  -- ^ optional incoming cotangent
  -> DValue src  -- morally DValue (ADTensorKind src)
{-# INLINE crevMaybe #-}
crevMaybe f vals0 mdt =
  let -- All input values, flattened into a single concrete tensor.
      valsTarget :: Concrete (X (DValue src))
      valsTarget = toTarget vals0
      -- Full shape of the flattened input, read off its runtime value.
      xftk :: FullShapeTK (X src)
      xftk = tftkG (knownSTK @(X src)) (unConcrete valsTarget)
      -- The objective function, restated over the flattened input.
      g :: ADVal Concrete (X src) -> tgt
      g = f . fromTarget
      -- Only the gradient component of the result is needed here.
      gradient = fst (crevOnParams mdt g xftk valsTarget)
  in fromTarget (fromADTensorKindShared (ftkToSTK xftk) gradient)


-- * Non-symbolic forward derivative adaptors

-- | Concrete (non-symbolic) forward derivative operation. It always takes
-- the perturbation parameter, by convention.
--
-- The result is the first component of the pair computed by 'cfwdBoth'.
-- The second component (of type @Concrete ztgt@) is discarded. Like
-- 'cfwdBoth', this calls 'error' when the perturbation does not have the
-- same shape as the input values.
cjvp
  :: forall src ztgt tgt.
     ( X src ~ X (DValue src), KnownSTK (X src)
     , AdaptableTarget (ADVal Concrete) src
     , AdaptableTarget Concrete (DValue src)
     , tgt ~ ADVal Concrete ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> DValue src  -- ^ the values at which the derivative is taken
  -> DValue src  -- ^ the perturbation; morally (ADTensorKind src)
  -> Concrete (ADTensorKind ztgt)
{-# INLINE cjvp #-}
cjvp f vals ds = fst (cfwdBoth f vals ds)


-- * Non-symbolic forward derivative adaptors' internal machinery

-- | Concrete (non-symbolic) forward derivative of an objective function.
--
-- The result is the pair produced by 'cfwdOnParams': the derivative (of
-- kind @ADTensorKind ztgt@) and a @Concrete ztgt@ value (presumably the
-- value of the objective; confirm against 'cfwdOnParams').
--
-- Both the input values and the perturbation are flattened into single
-- 'Concrete' tensors with 'toTarget'. The perturbation must have exactly
-- the full shape of the input values; otherwise this calls 'error'.
-- Before it reaches 'cfwdOnParams', the perturbation is converted to its
-- 'ADTensorKind' form with 'toADTensorKindShared'.
cfwdBoth
  :: forall src ztgt tgt.
     ( X src ~ X (DValue src), KnownSTK (X src)
     , AdaptableTarget (ADVal Concrete) src
     , AdaptableTarget Concrete (DValue src)
     , tgt ~ ADVal Concrete ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> DValue src  -- ^ the values at which the derivative is taken
  -> DValue src  -- ^ the perturbation; morally (ADTensorKind src)
  -> (Concrete (ADTensorKind ztgt), Concrete ztgt)
{-# INLINE cfwdBoth #-}
cfwdBoth f vals0 ds =
  let -- All input values, flattened into a single concrete tensor.
      valsTarget :: Concrete (X (DValue src))
      valsTarget = toTarget vals0
      -- Full shape of the flattened input, read off its runtime value.
      xftk :: FullShapeTK (X src)
      xftk = tftkG (knownSTK @(X src)) (unConcrete valsTarget)
      -- The objective function, restated over the flattened input.
      g :: ADVal Concrete (X src) -> tgt
      g = f . fromTarget
      -- The perturbation, flattened the same way as the input values.
      dsTarget :: Concrete (X (DValue src))
      dsTarget = toTarget ds
      -- Full shape of the perturbation, read off its runtime value, so it
      -- can be checked against the shape of the input values.
      dsFtk :: FullShapeTK (X src)
      dsFtk = tftkG (ftkToSTK xftk) (unConcrete dsTarget)
  in if dsFtk == xftk
     then cfwdOnParams xftk valsTarget g (toADTensorKindShared xftk dsTarget)
     else error "cfwdBoth: forward derivative input must have the same shape as the perturbation argument"





-- Specializations, to the 'Concrete' target, of the functions that
-- evaluate 'Delta' terms. These specializations cannot be stated where
-- the functions are defined, because of module dependency cycles, but
-- they can be stated here.
{-# SPECIALIZE gradientFromDelta :: FullShapeTK x -> FullShapeTK z -> Concrete (ADTensorKind z) -> Delta Concrete z -> Concrete (ADTensorKind x) #-}
{-# SPECIALIZE evalRev :: FullShapeTK y -> EvalState Concrete -> Concrete (ADTensorKind y) -> Delta Concrete y -> EvalState Concrete #-}
{-# SPECIALIZE evalRevFTK :: EvalState Concrete -> Concrete (ADTensorKind y) -> Delta Concrete y -> EvalState Concrete #-}
-- The fully polymorphic specialization of 'evalRevSame' below stays
-- commented out, because GHC rejects it with
-- "RULE left-hand side too complicated to desugar":
-- {-# SPECIALIZE evalRevSame :: y ~ ADTensorKind y => EvalState Concrete -> Concrete (ADTensorKind y) -> Delta Concrete y -> EvalState Concrete #-}
{-# SPECIALIZE evalRevFromnMap :: EvalState Concrete -> EvalState Concrete #-}

-- Instead, 'evalRevSame' is specialized separately for each of the tensor
-- kinds TKScalar, TKR, TKS and TKX, at element types Double and Float.
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKScalar Double) -> Delta Concrete (TKScalar Double) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKScalar Float) -> Delta Concrete (TKScalar Float) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKR n Double) -> Delta Concrete (TKR n Double) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKR n Float) -> Delta Concrete (TKR n Float) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKS sh Double) -> Delta Concrete (TKS sh Double) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKS sh Float) -> Delta Concrete (TKS sh Float) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKX sh Double) -> Delta Concrete (TKX sh Double) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKX sh Float) -> Delta Concrete (TKX sh Float) -> EvalState Concrete #-}


-- These and all other SPECIALIZE pragmas are needed because of issue
-- #21286 and related issues (by now mostly fixed). They are kept at least
-- so that the generated output can be compared with and without them.
--
-- 'interpretAst' is specialized for all three AstSpan values (PrimalSpan,
-- DualSpan, FullSpan), to handle recursive calls from interpretAstDual,
-- etc. For each span it is specialized to environments over
-- ADVal Concrete, ADVal (AstRaw PrimalSpan) and Concrete.
{-# SPECIALIZE interpretAst
  :: AstEnv (ADVal Concrete)
  -> AstTensor AstMethodLet PrimalSpan y
  -> ADVal Concrete y #-}
{-# SPECIALIZE interpretAst
  :: AstEnv (ADVal (AstRaw PrimalSpan))
  -> AstTensor AstMethodLet PrimalSpan y
  -> ADVal (AstRaw PrimalSpan) y #-}
{-# SPECIALIZE interpretAst
  :: AstEnv Concrete
  -> AstTensor AstMethodLet PrimalSpan y
  -> Concrete y #-}
{-# SPECIALIZE interpretAst
  :: AstEnv (ADVal Concrete)
  -> AstTensor AstMethodLet DualSpan y
  -> ADVal Concrete y #-}
{-# SPECIALIZE interpretAst
  :: AstEnv (ADVal (AstRaw PrimalSpan))
  -> AstTensor AstMethodLet DualSpan y
  -> ADVal (AstRaw PrimalSpan) y #-}
{-# SPECIALIZE interpretAst
  :: AstEnv Concrete
  -> AstTensor AstMethodLet DualSpan y
  -> Concrete y #-}
{-# SPECIALIZE interpretAst
  :: AstEnv (ADVal Concrete)
  -> AstTensor AstMethodLet FullSpan y
  -> ADVal Concrete y #-}
{-# SPECIALIZE interpretAst
  :: AstEnv (ADVal (AstRaw PrimalSpan))
  -> AstTensor AstMethodLet FullSpan y
  -> ADVal (AstRaw PrimalSpan) y #-}
{-# SPECIALIZE interpretAst
  :: AstEnv Concrete
  -> AstTensor AstMethodLet FullSpan y
  -> Concrete y #-}

-- 'interpretAstPrimal' for the same three environment types, applied to
-- PrimalSpan terms only.
{-# SPECIALIZE interpretAstPrimal
  :: AstEnv (ADVal Concrete)
  -> AstTensor AstMethodLet PrimalSpan y
  -> Concrete y #-}
{-# SPECIALIZE interpretAstPrimal
  :: AstEnv (ADVal (AstRaw PrimalSpan))
  -> AstTensor AstMethodLet PrimalSpan y
  -> AstRaw PrimalSpan y #-}
{-# SPECIALIZE interpretAstPrimal
  :: AstEnv Concrete
  -> AstTensor AstMethodLet PrimalSpan y
  -> Concrete y #-}

-- 'interpretAstBool' for the same three environment types. The result is
-- a plain 'Bool' for the ADVal Concrete and Concrete environments, and
-- an 'AstBool AstMethodShare' for the ADVal (AstRaw PrimalSpan) one.
{-# SPECIALIZE interpretAstBool
  :: AstEnv (ADVal Concrete)
  -> AstBool AstMethodLet
  -> Bool #-}
{-# SPECIALIZE interpretAstBool
  :: AstEnv (ADVal (AstRaw PrimalSpan))
  -> AstBool AstMethodLet
  -> AstBool AstMethodShare #-}
{-# SPECIALIZE interpretAstBool
  :: AstEnv Concrete
  -> AstBool AstMethodLet
  -> Bool #-}