horde-ad-0.2.0.0: Higher Order Reverse Derivatives Efficiently - Automatic Differentiation
Safe Haskell: None
Language: GHC2024

HordeAd.Core.Delta

Description

The grammar of delta expressions.

A delta expression can be viewed as a concise representation of a linear map (the derivative of the objective function), and its evaluation on a given argument (in module HordeAd.Core.DeltaEval) as the adjoint (in the algebraic sense) of that linear map applied to the argument. Since linear maps can be represented as matrices, this operation corresponds to transposing the matrix. However, the matrix is never constructed; it is represented, and transposed, in a way that preserves the sparsity of the representation.
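As a minimal illustration of this reading, the sketch below uses a deliberately tiny, scalar-only grammar. The type D and the functions applyD, transposeD and dotI are hypothetical and are not horde-ad's API. applyD applies the linear map denoted by a term to a tangent, while transposeD applies the transposed map to a cotangent by walking the term, accumulating one number per input and never building a matrix.

```haskell
import qualified Data.IntMap.Strict as IM

-- Hypothetical scalar-only delta grammar, far simpler than horde-ad's Delta.
data D
  = Zero
  | Input Int       -- index of an input parameter
  | Scale Double D  -- multiply by a primal constant
  | Add D D

-- Apply the linear map denoted by the term to a tangent of the inputs.
applyD :: IM.IntMap Double -> D -> Double
applyD _ Zero = 0
applyD v (Input i) = IM.findWithDefault 0 i v
applyD v (Scale c d) = c * applyD v d
applyD v (Add d e) = applyD v d + applyD v e

-- Apply the transposed map to a cotangent, accumulating one value per input.
transposeD :: Double -> D -> IM.IntMap Double
transposeD ct0 t0 = go ct0 t0 IM.empty
  where
    go :: Double -> D -> IM.IntMap Double -> IM.IntMap Double
    go _ Zero acc = acc
    go c (Input i) acc = IM.insertWith (+) i c acc
    go c (Scale k d) acc = go (k * c) d acc
    go c (Add d e) acc = go c e (go c d acc)

-- Dot product of sparse vectors; missing entries count as zero.
dotI :: IM.IntMap Double -> IM.IntMap Double -> Double
dotI a b = sum (IM.intersectionWith (*) a b)

-- Derivative of f(x0, x1) = 3*x0 + x0*x1 at (x0, x1) = (2, 5).
exampleD :: D
exampleD = Add (Scale 3 (Input 0)) (Add (Scale 5 (Input 0)) (Scale 2 (Input 1)))
```

Here transposeD 1 exampleD is the gradient, mapping input 0 to 8 and input 1 to 2. For any tangent v and cotangent c, c * applyD v exampleD equals dotI (transposeD c exampleD) v, which is exactly the adjoint property.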

The 'sparsity' is less obvious when a delta expression contains big concrete tensors, e.g., via the DeltaScale constructor. However, via DeltaReplicate and other constructors, the tensors can be enlarged far beyond what is embedded in the delta term. Also, if the expression refers to unknown inputs (DeltaInput), it may denote, after evaluation, a still larger tensor.
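To make the size argument concrete, here is a toy sketch (the type V and all functions are hypothetical, not horde-ad code) in which a three-constructor term denotes a 10000-element vector. Its transpose only ever shrinks the incoming cotangent: the transpose of replicating k copies is summing the k copies back together.

```haskell
-- Hypothetical toy vector terms; not horde-ad's Delta.
data V
  = VInput Int        -- a scalar input, viewed as a length-1 vector
  | VReplicate Int V  -- k copies along a new outer dimension, stored flat

-- Number of constructors in the term.
termSize :: V -> Int
termSize (VInput _) = 1
termSize (VReplicate _ d) = 1 + termSize d

-- Number of elements the term denotes once evaluated.
denotedLength :: V -> Int
denotedLength (VInput _) = 1
denotedLength (VReplicate k d) = k * denotedLength d

-- Transpose: fold a full-length cotangent back onto the input it came from.
transposeV :: [Double] -> V -> [(Int, Double)]
transposeV ct (VInput i) = [(i, sum ct)]  -- ct has length 1 here
transposeV ct (VReplicate k d) = transposeV (sumChunks k (denotedLength d) ct) d

-- Split ct into k consecutive chunks of length n and add them elementwise.
sumChunks :: Int -> Int -> [Double] -> [Double]
sumChunks k n ct = foldr (zipWith (+)) (replicate n 0) (chunks k ct)
  where
    chunks :: Int -> [Double] -> [[Double]]
    chunks 0 _ = []
    chunks j xs = let (h, t) = splitAt n xs in h : chunks (j - 1) t

big :: V
big = VReplicate 100 (VReplicate 100 (VInput 0))
```

For big, termSize is 3 while denotedLength is 10000, and transposing an all-ones cotangent yields the single pair (0, 10000.0).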

The algebraic structure here is an extension of a vector space with some additional constructors. The crucial extra constructor DeltaInput replaces the usual one-hot access to parameters with something cheaper and more uniform. Many of the remaining additional constructors introduce and reduce dimensions of tensors, mimicking many of the operations available for primal value arrays.
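The cost difference can be sketched with plain lists. All names below are hypothetical and for illustration only: accessing parameter i of n through a one-hot basis vector makes each gradient contribution a dense vector of length n, whereas an input leaf that names i directly contributes a single (index, value) pair denoting the same vector.

```haskell
-- Hypothetical illustration, not horde-ad code.
-- Selecting parameter i out of n by a dot product with the i-th basis vector:
oneHot :: Int -> Int -> [Double]
oneHot n i = [if j == i then 1 else 0 | j <- [0 .. n - 1]]

-- The gradient contribution of a cotangent ct flowing into a one-hot access
-- is a dense vector with n entries, almost all of them zero.
contribOneHot :: Int -> Int -> Double -> [Double]
contribOneHot n i ct = map (ct *) (oneHot n i)

-- With an input leaf referring to i directly, the same contribution is a
-- single (index, value) pair.
contribInput :: Int -> Double -> (Int, Double)
contribInput i ct = (i, ct)

-- Densifying the pair recovers the one-hot contribution.
densify :: Int -> (Int, Double) -> [Double]
densify n (i, c) = [if j == i then c else 0 | j <- [0 .. n - 1]]
```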


Delta identifiers

data NodeId (a :: Target) (b :: TK)

The identifiers for nodes of delta expression trees.

Instances

Enum1 (NodeId target :: TK -> Type)

Defined in HordeAd.Core.Delta

Associated Types

type Enum1Info (NodeId target :: TK -> Type) 

Defined in HordeAd.Core.Delta

type Enum1Info (NodeId target :: TK -> Type) = Some FullShapeTK

Methods

fromEnum1 :: forall (a :: TK). NodeId target a -> (Int, Enum1Info (NodeId target))

toEnum1 :: Int -> Enum1Info (NodeId target) -> Some (NodeId target)

TestEquality (NodeId target :: TK -> Type)

Defined in HordeAd.Core.Delta

Methods

testEquality :: forall (a :: TK) (b :: TK). NodeId target a -> NodeId target b -> Maybe (a :~: b)

Show (NodeId target y)

Defined in HordeAd.Core.Delta

Methods

showsPrec :: Int -> NodeId target y -> ShowS

show :: NodeId target y -> String

showList :: [NodeId target y] -> ShowS

mkNodeId :: forall (y :: TK) (f :: Target). FullShapeTK y -> Int -> NodeId f y

Wrap non-negative (only!) integers in the NodeId newtype.
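The constructors of NodeId are not shown on this page, so the following is only a sketch of what such a smart constructor might look like, assuming an identifier that pairs an Int with the full shape of its node (compare nodeIdToFTK and the Enum1Info instance, which is Some FullShapeTK). ToyNodeId, Shape and the functions are hypothetical, and rejecting negative integers with error is an assumption about how the "only!" restriction might be enforced.

```haskell
-- Hypothetical stand-ins for NodeId, FullShapeTK, mkNodeId and nodeIdToFTK.
newtype Shape = Shape [Int] deriving (Eq, Show)

-- An identifier tagged with the full shape of the node it names.
data ToyNodeId = ToyNodeId Shape Int deriving (Eq, Show)

-- Smart constructor: accept only non-negative integers.
mkToyNodeId :: Shape -> Int -> ToyNodeId
mkToyNodeId sh n
  | n >= 0 = ToyNodeId sh n
  | otherwise = error ("mkToyNodeId: negative identifier " ++ show n)

-- Recover the shape without consulting the node itself (cf. nodeIdToFTK).
toyNodeIdShape :: ToyNodeId -> Shape
toyNodeIdShape (ToyNodeId sh _) = sh

-- Split into the integer and the extra information (cf. fromEnum1).
toyFromEnum :: ToyNodeId -> (Int, Shape)
toyFromEnum (ToyNodeId sh n) = (n, sh)
```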

nodeIdToFTK :: forall (f :: Target) (y :: TK). NodeId f y -> FullShapeTK y

data InputId (a :: Target) (b :: TK)

The identifiers for input leaves of delta expressions.

Instances

Enum1 (InputId target :: TK -> Type)

Defined in HordeAd.Core.Delta

Associated Types

type Enum1Info (InputId target :: TK -> Type) 

Defined in HordeAd.Core.Delta

type Enum1Info (InputId target :: TK -> Type) = Some FullShapeTK

Methods

fromEnum1 :: forall (a :: TK). InputId target a -> (Int, Enum1Info (InputId target))

toEnum1 :: Int -> Enum1Info (InputId target) -> Some (InputId target)

TestEquality (InputId target :: TK -> Type)

Defined in HordeAd.Core.Delta

Methods

testEquality :: forall (a :: TK) (b :: TK). InputId target a -> InputId target b -> Maybe (a :~: b)

Show (InputId target y)

Defined in HordeAd.Core.Delta

Methods

showsPrec :: Int -> InputId target y -> ShowS

show :: InputId target y -> String

showList :: [InputId target y] -> ShowS

mkInputId :: forall (y :: TK) (f :: Target). FullShapeTK y -> Int -> InputId f y

Wrap non-negative (only!) integers in the InputId newtype.

inputIdToFTK :: forall (f :: Target) (y :: TK). InputId f y -> FullShapeTK y

The grammar of delta expressions

data Delta (a :: Target) (b :: TK) where

The grammar of delta expressions.

The NodeId identifier that appears in a DeltaShare n d expression is the unique identity stamp of subterm d; that is, there is no different term e such that DeltaShare n e appears in any delta expression term in memory during the same run of an executable. The subterm identity is used to avoid evaluating shared subterms repeatedly in gradient and derivative computations. The identifiers also represent data dependencies among terms for the purpose of gradient and derivative computation: computation for a term may depend only on data obtained from terms with lower node identifiers. This way of determining data dependencies agrees with the subterm relation, but is faster than traversing the term tree to determine how terms are related.
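The following hypothetical sketch shows why unique identifiers that grow with the term help. The type T and all functions are illustrative, and this is not the evaluator in HordeAd.Core.DeltaEval. In chain k each shared subterm is referenced twice, so naive recursion visits shared bodies 2^k - 1 times, while a reverse pass that accumulates cotangents per identifier and always processes the largest pending identifier first visits each shared body once.

```haskell
import qualified Data.IntMap.Strict as IM

-- Hypothetical toy grammar with sharing; not horde-ad's Delta.
data T
  = TInput Int    -- input leaf
  | TAdd T T      -- sum of two tangents
  | TShare Int T  -- node identifier stamped on a shared subterm

-- A term whose shared subterm is referenced twice at every level.
-- Identifiers grow as terms are built, so every subterm has a lower id.
chain :: Int -> T
chain 0 = TInput 0
chain k = TShare k (TAdd sub sub)
  where sub = chain (k - 1)

-- Shared bodies visited by naive tree recursion: exponential in depth.
naiveVisits :: T -> Int
naiveVisits (TInput _) = 0
naiveVisits (TAdd d e) = naiveVisits d + naiveVisits e
naiveVisits (TShare _ d) = 1 + naiveVisits d

type Pending = IM.IntMap (Double, T)  -- node id -> (accumulated cotangent, body)

-- Reverse pass: cotangents reaching a TShare node are summed in a map keyed
-- by its id, and the pending node with the largest id is processed next.
-- All terms containing it have larger ids and were processed already, so
-- its cotangent is complete when popped. Returns input cotangents and the
-- number of shared bodies traversed.
gradient :: T -> (IM.IntMap Double, Int)
gradient t0 = loop 0 (push 1 t0 (IM.empty, IM.empty))
  where
    push :: Double -> T -> (IM.IntMap Double, Pending) -> (IM.IntMap Double, Pending)
    push c (TInput i) (ins, pend) = (IM.insertWith (+) i c ins, pend)
    push c (TAdd d e) st = push c e (push c d st)
    push c (TShare n d) (ins, pend) =
      (ins, IM.insertWith (\(c1, b) (c2, _) -> (c1 + c2, b)) n (c, d) pend)
    loop :: Int -> (IM.IntMap Double, Pending) -> (IM.IntMap Double, Int)
    loop visits (ins, pend) = case IM.maxViewWithKey pend of
      Nothing -> (ins, visits)
      Just ((_, (c, d)), rest) -> loop (visits + 1) (push c d (ins, rest))
```

For example, gradient (chain 3) assigns cotangent 8 to input 0 after 3 visits, whereas naiveVisits (chain 3) is 7.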

When computing gradients, node identifiers are also used to index, directly or indirectly, the data accumulated for each node in the form of cotangents, that is, partial derivatives of the objective function with respect to the position(s) of the node in the whole objective function dual number term (or, more precisely, with respect to the single node in the term DAG, in which subterms with the same node identifier are collapsed). Only the DeltaInput nodes have separate data storage: the InputId identifiers in the DeltaInput term constructors are indexes into a contiguous vector of cotangents of the DeltaInput subterms of the whole term. The value at a given index is the partial derivative of the objective function (represented by the whole term or, more precisely, by the data flow graph of the particular evaluation from which the delta expression originates) with respect to the input parameter component at that index in the objective function's domain.
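A minimal sketch of the contiguous input storage, assuming inputs are numbered 0 to nInputs - 1 and that a reverse pass emits one (index, value) contribution per occurrence of an input. gradientVector is a hypothetical name; the accumulation itself is done by accumArray from Data.Array.

```haskell
import Data.Array (Array, accumArray, elems)

-- Hypothetical: collect per-occurrence cotangent contributions
-- (input index, value) into one contiguous vector indexed by input.
-- An input occurring several times receives the sum of its contributions;
-- an input that never occurs gets a zero cotangent.
gradientVector :: Int -> [(Int, Double)] -> Array Int Double
gradientVector nInputs = accumArray (+) 0 (0, nInputs - 1)
```

For instance, elems (gradientVector 4 [(0, 1.5), (2, 2), (0, 0.5)]) is [2.0, 0.0, 2.0, 0.0].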

Constructors

DeltaShare :: forall (a :: Target) (b :: TK). NodeId a b -> Delta a b -> Delta a b 
DeltaInput :: forall (a :: Target) (b :: TK). InputId a b -> Delta a b 
DeltaPair :: forall (y :: TK) (z :: TK) (a :: Target). Delta a y -> Delta a z -> Delta a ('TKProduct y z) 
DeltaProject1 :: forall (b :: TK) (z :: TK) (a :: Target). Delta a ('TKProduct b z) -> Delta a b 
DeltaProject2 :: forall (y :: TK) (b :: TK) (a :: Target). Delta a ('TKProduct y b) -> Delta a b 
DeltaFromVector :: forall (y :: TK) (k :: Nat) (a :: Target). SNat k -> SingletonTK y -> Vector (Delta a y) -> Delta a (BuildTensorKind k y) 
DeltaSum :: forall (b :: TK) (k :: Nat) (a :: Target). SNat k -> SingletonTK b -> Delta a (BuildTensorKind k b) -> Delta a b 
DeltaReplicate :: forall (y :: TK) (k :: Nat) (a :: Target). SNat k -> SingletonTK y -> Delta a y -> Delta a (BuildTensorKind k y) 
DeltaMapAccumR :: forall (a :: Target) (k :: Nat) (accy :: TK) (by :: TK) (ey :: TK). (Show (a (BuildTensorKind k accy)), Show (a (BuildTensorKind k ey))) => SNat k -> FullShapeTK by -> FullShapeTK ey -> a (BuildTensorKind k accy) -> a (BuildTensorKind k ey) -> HFun ('TKProduct (ADTensorKind ('TKProduct accy ey)) ('TKProduct accy ey)) (ADTensorKind ('TKProduct accy by)) -> HFun ('TKProduct (ADTensorKind ('TKProduct accy by)) ('TKProduct accy ey)) (ADTensorKind ('TKProduct accy ey)) -> Delta a accy -> Delta a (BuildTensorKind k ey) -> Delta a ('TKProduct accy (BuildTensorKind k by)) 
DeltaMapAccumL :: forall (a :: Target) (k :: Nat) (accy :: TK) (by :: TK) (ey :: TK). (Show (a (BuildTensorKind k accy)), Show (a (BuildTensorKind k ey))) => SNat k -> FullShapeTK by -> FullShapeTK ey -> a (BuildTensorKind k accy) -> a (BuildTensorKind k ey) -> HFun ('TKProduct (ADTensorKind ('TKProduct accy ey)) ('TKProduct accy ey)) (ADTensorKind ('TKProduct accy by)) -> HFun ('TKProduct (ADTensorKind ('TKProduct accy by)) ('TKProduct accy ey)) (ADTensorKind ('TKProduct accy ey)) -> Delta a accy -> Delta a (BuildTensorKind k ey) -> Delta a ('TKProduct accy (BuildTensorKind k by)) 
DeltaZero :: forall (b :: TK) (a :: Target). FullShapeTK b -> Delta a b 
DeltaScale :: forall (a :: Target) (b :: TK). Num (a b) => NestedTarget a b -> Delta a b -> Delta a b 
DeltaAdd :: forall (a :: Target) (b :: TK). Num (a b) => Delta a b -> Delta a b -> Delta a b 
DeltaCastK :: forall r1 r2 (a :: Target). (GoodScalar r1, RealFrac r1, GoodScalar r2, RealFrac r2) => Delta a ('TKScalar r1) -> Delta a ('TKScalar r2) 
DeltaCastR :: forall r1 r2 (a :: Target) (n :: Nat). (GoodScalar r1, RealFrac r1, GoodScalar r2, RealFrac r2) => Delta a (TKR n r1) -> Delta a ('TKR2 n ('TKScalar r2)) 
DeltaSum0R :: forall (a :: Target) (n :: Nat) (r :: TK). Delta a ('TKR2 n r) -> Delta a ('TKR2 0 r) 
DeltaDot0R :: forall r (a :: Target) (n :: Nat). (GoodScalar r, Show (a (TKR n r))) => a (TKR n r) -> Delta a (TKR n r) -> Delta a ('TKR2 0 ('TKScalar r)) 
DeltaIndexR :: forall (m :: Natural) (n :: Nat) (r :: TK) (a :: Target). SNat n -> Delta a ('TKR2 (m + n) r) -> IxROf a m -> Delta a ('TKR2 n r) 
DeltaScatterR :: forall (m :: Nat) (n :: Natural) (p :: Natural) (r :: TK) (a :: Target). SNat m -> SNat n -> SNat p -> IShR (p + n) -> Delta a ('TKR2 (m + n) r) -> (IxROf a m -> IxROf a p) -> Delta a ('TKR2 (p + n) r) 
DeltaGatherR :: forall (m :: Natural) (n :: Natural) (p :: Nat) (r :: TK) (a :: Target). SNat m -> SNat n -> SNat p -> IShR (m + n) -> Delta a ('TKR2 (p + n) r) -> (IxROf a m -> IxROf a p) -> Delta a ('TKR2 (m + n) r) 
DeltaAppendR :: forall (a :: Target) (n :: Natural) (r :: TK). Delta a ('TKR2 (1 + n) r) -> Delta a ('TKR2 (1 + n) r) -> Delta a ('TKR2 (1 + n) r) 
DeltaSliceR :: forall (a :: Target) (n :: Natural) (r :: TK). Int -> Int -> Delta a ('TKR2 (1 + n) r) -> Delta a ('TKR2 (1 + n) r) 
DeltaReverseR :: forall (a :: Target) (n :: Natural) (r :: TK). Delta a ('TKR2 (1 + n) r) -> Delta a ('TKR2 (1 + n) r) 
DeltaTransposeR :: forall (a :: Target) (n :: Nat) (r :: TK). PermR -> Delta a ('TKR2 n r) -> Delta a ('TKR2 n r) 
DeltaReshapeR :: forall (m :: Nat) (a :: Target) (n :: Nat) (r :: TK). IShR m -> Delta a ('TKR2 n r) -> Delta a ('TKR2 m r) 
DeltaCastS :: forall r1 r2 (a :: Target) (sh :: [Nat]). (GoodScalar r1, RealFrac r1, GoodScalar r2, RealFrac r2) => Delta a (TKS sh r1) -> Delta a ('TKS2 sh ('TKScalar r2)) 
DeltaSum0S :: forall (a :: Target) (sh :: [Nat]) (r :: TK). Delta a ('TKS2 sh r) -> Delta a ('TKS2 ('[] :: [Nat]) r) 
DeltaDot0S :: forall r (a :: Target) (sh :: [Nat]). (GoodScalar r, Show (a (TKS sh r))) => a (TKS sh r) -> Delta a (TKS sh r) -> Delta a ('TKS2 ('[] :: [Nat]) ('TKScalar r)) 
DeltaIndexS :: forall (shm :: [Nat]) (shn :: [Nat]) (r :: TK) (a :: Target). ShS shn -> Delta a ('TKS2 (shm ++ shn) r) -> IxSOf a shm -> Delta a ('TKS2 shn r) 
DeltaScatterS :: forall (shm :: [Nat]) (shn :: [Nat]) (shp :: [Nat]) (r :: TK) (a :: Target). ShS shm -> ShS shn -> ShS shp -> Delta a ('TKS2 (shm ++ shn) r) -> (IxSOf a shm -> IxSOf a shp) -> Delta a ('TKS2 (shp ++ shn) r) 
DeltaGatherS :: forall (shm :: [Nat]) (shn :: [Nat]) (shp :: [Nat]) (r :: TK) (a :: Target). ShS shm -> ShS shn -> ShS shp -> Delta a ('TKS2 (shp ++ shn) r) -> (IxSOf a shm -> IxSOf a shp) -> Delta a ('TKS2 (shm ++ shn) r) 
DeltaAppendS :: forall (a :: Target) (r :: TK) (m :: Natural) (n :: Natural) (sh :: [Natural]). Delta a ('TKS2 (m ': sh) r) -> Delta a ('TKS2 (n ': sh) r) -> Delta a ('TKS2 ((m + n) ': sh) r) 
DeltaSliceS :: forall (i :: Nat) (n :: Nat) (k :: Nat) (a :: Target) (sh :: [Nat]) (r :: TK). SNat i -> SNat n -> SNat k -> Delta a ('TKS2 (((i + n) + k) ': sh) r) -> Delta a ('TKS2 (n ': sh) r) 
DeltaReverseS :: forall (a :: Target) (n :: Nat) (sh :: [Nat]) (r :: TK). Delta a ('TKS2 (n ': sh) r) -> Delta a ('TKS2 (n ': sh) r) 
DeltaTransposeS :: forall (perm :: [Natural]) (sh :: [Nat]) (r :: TK) (a :: Target). (IsPermutation perm, Rank perm <= Rank sh) => Perm perm -> Delta a ('TKS2 sh r) -> Delta a ('TKS2 (PermutePrefix perm sh) r) 
DeltaReshapeS :: forall (sh :: [Natural]) (sh2 :: [Natural]) (a :: Target) (r :: TK). Product sh ~ Product sh2 => ShS sh2 -> Delta a ('TKS2 sh r) -> Delta a ('TKS2 sh2 r) 
DeltaCastX :: forall r1 r2 (a :: Target) (sh :: [Maybe Nat]). (GoodScalar r1, RealFrac r1, GoodScalar r2, RealFrac r2) => Delta a (TKX sh r1) -> Delta a ('TKX2 sh ('TKScalar r2)) 
DeltaSum0X :: forall (a :: Target) (sh :: [Maybe Nat]) (r :: TK). Delta a ('TKX2 sh r) -> Delta a ('TKX2 ('[] :: [Maybe Nat]) r) 
DeltaDot0X :: forall r (a :: Target) (sh :: [Maybe Nat]). (GoodScalar r, Show (a (TKX sh r))) => a (TKX sh r) -> Delta a (TKX sh r) -> Delta a ('TKX2 ('[] :: [Maybe Nat]) ('TKScalar r)) 
DeltaIndexX :: forall (shm :: [Maybe Nat]) (shn :: [Maybe Nat]) (r :: TK) (a :: Target). StaticShX shn -> Delta a ('TKX2 (shm ++ shn) r) -> IxXOf a shm -> Delta a ('TKX2 shn r) 
DeltaScatterX :: forall (shm :: [Maybe Nat]) (shn :: [Maybe Nat]) (shp :: [Maybe Nat]) (a :: Target) (r :: TK). StaticShX shm -> StaticShX shn -> StaticShX shp -> IShX (shp ++ shn) -> Delta a ('TKX2 (shm ++ shn) r) -> (IxXOf a shm -> IxXOf a shp) -> Delta a ('TKX2 (shp ++ shn) r) 
DeltaGatherX :: forall (shm :: [Maybe Nat]) (shn :: [Maybe Nat]) (shp :: [Maybe Nat]) (a :: Target) (r :: TK). StaticShX shm -> StaticShX shn -> StaticShX shp -> IShX (shm ++ shn) -> Delta a ('TKX2 (shp ++ shn) r) -> (IxXOf a shm -> IxXOf a shp) -> Delta a ('TKX2 (shm ++ shn) r) 
DeltaAppendX :: forall (a :: Target) (m :: Natural) (sh :: [Maybe Natural]) (r :: TK) (n :: Natural). Delta a ('TKX2 ('Just m ': sh) r) -> Delta a ('TKX2 ('Just n ': sh) r) -> Delta a ('TKX2 ('Just (m + n) ': sh) r) 
DeltaSliceX :: forall (i :: Nat) (n :: Nat) (k :: Nat) (a :: Target) (sh :: [Maybe Nat]) (r :: TK). SNat i -> SNat n -> SNat k -> Delta a ('TKX2 ('Just ((i + n) + k) ': sh) r) -> Delta a ('TKX2 ('Just n ': sh) r) 
DeltaReverseX :: forall (a :: Target) (mn :: Maybe Nat) (sh :: [Maybe Nat]) (r :: TK). Delta a ('TKX2 (mn ': sh) r) -> Delta a ('TKX2 (mn ': sh) r) 
DeltaTransposeX :: forall (perm :: [Natural]) (sh :: [Maybe Nat]) (r :: TK) (a :: Target). (IsPermutation perm, Rank perm <= Rank sh) => Perm perm -> Delta a ('TKX2 sh r) -> Delta a ('TKX2 (PermutePrefix perm sh) r) 
DeltaReshapeX :: forall (sh2 :: [Maybe Nat]) (a :: Target) (sh :: [Maybe Nat]) (r :: TK). IShX sh2 -> Delta a ('TKX2 sh r) -> Delta a ('TKX2 sh2 r) 
DeltaConvert :: forall (a1 :: TK) (b :: TK) (a :: Target). TKConversion a1 b -> Delta a a1 -> Delta a b 
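Some of these constructors come in transposition pairs. As a sketch on one-dimensional lists (gather1, scatter1 and dot are hypothetical helpers, not horde-ad functions), a gather and a scatter that share the same index function f, mapping positions on the m-indexed side to positions on the p-indexed side much like the IxROf a m -> IxROf a p argument of DeltaGatherR and DeltaScatterR, are adjoint to each other: dot (gather1 m f x) y equals dot x (scatter1 p f y). This is the kind of identity that lets a transposing evaluator turn one into the other; how HordeAd.Core.DeltaEval actually handles them is not shown on this page.

```haskell
-- Hypothetical 1-D analogues of gather and scatter; not horde-ad's API.
-- The index function f maps positions of the m-element side to positions
-- of the p-element side, for both operations.

-- Output position i (0 <= i < m) reads source position f i.
gather1 :: Int -> (Int -> Int) -> [Double] -> [Double]
gather1 m f xs = [xs !! f i | i <- [0 .. m - 1]]

-- Source position i is added into output position f i (0 <= f i < p);
-- collisions are summed.
scatter1 :: Int -> (Int -> Int) -> [Double] -> [Double]
scatter1 p f ys = [sum [y | (i, y) <- zip [0 ..] ys, f i == j] | j <- [0 .. p - 1]]

-- Euclidean inner product.
dot :: [Double] -> [Double] -> Double
dot a b = sum (zipWith (*) a b)
```

With f = (`mod` 3), x = [1, 2, 3] and y = [10, 20, 30, 40, 50], both sides of the identity equal 280.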

Instances

Show (IntOf target) => Show (Delta target y)

Defined in HordeAd.Core.Delta

Methods

showsPrec :: Int -> Delta target y -> ShowS

show :: Delta target y -> String

showList :: [Delta target y] -> ShowS

newtype NestedTarget (target :: Target) (y :: TK)

A newtype defined only to cut the knot of Show instances in DeltaScale: those instances are problematic to pass around as dictionaries without bloating each constructor. The DeltaScale constructor appears often in delta expressions, so printing its primal subterm would bloat the pretty-printed output (though, on the other hand, the primal terms are often important).

Possibly, Has Show (Delta target) would be a better solution.
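The knot-cutting idea can be sketched with a hypothetical newtype Opaque whose Show instance needs no Show constraint on its payload, so a derived Show instance for a type that stores Opaque values needs no Show dictionary for the payload type either. What the library's NestedTarget instance actually prints is not documented on this page; the "<primal>" placeholder is an assumption.

```haskell
-- Hypothetical sketch of the knot-cutting newtype; not horde-ad's NestedTarget.
newtype Opaque a = Opaque a

-- Prints a fixed placeholder, so no Show instance for the payload is needed.
instance Show (Opaque a) where
  showsPrec _ _ = showString "<primal>"

-- A toy expression type that stores primal payloads in scale nodes.
-- Its derived Show instance needs no Show constraint on a; with a bare
-- field of type a instead of Opaque a, it would require Show a.
data Expr a
  = Lit Double
  | ScaleBy (Opaque a) (Expr a)
  deriving Show
```

Thus show (ScaleBy (Opaque (+ (1 :: Int))) (Lit 2)) gives "ScaleBy <primal> (Lit 2.0)", even though functions have no Show instance.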

Constructors

NestedTarget (target y) 

Instances

Show (NestedTarget target y)

Defined in HordeAd.Core.Delta

Methods

showsPrec :: Int -> NestedTarget target y -> ShowS

show :: NestedTarget target y -> String

showList :: [NestedTarget target y] -> ShowS

Full tensor kind derivation for delta expressions

ftkDelta :: forall (target :: Target) (y :: TK). Delta target y -> FullShapeTK y

Full tensor kind derivation for delta expressions.
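In the same spirit, here is a hypothetical toy shape derivation (the type S and shapeOf are illustrative only) that reads the shape of a term off its constructors, following the signatures of DeltaReplicate, DeltaSum, DeltaIndexR and DeltaTransposeS, without evaluating any tensor. The transpose case follows the PermutePrefix convention as we understand it: the permutation reorders the leading dimensions and leaves the rest in place.

```haskell
-- Hypothetical toy shape derivation; not horde-ad's ftkDelta.
data S
  = SInput [Int]        -- input leaf with its declared shape
  | SReplicate Int S    -- adds an outer dimension of size k
  | SSum S              -- sums away the outer dimension
  | SIndex Int S        -- indexes away the first m dimensions
  | STranspose [Int] S  -- permutes the leading (length perm) dimensions

-- Derive the shape from the constructors alone; no tensor is evaluated.
shapeOf :: S -> [Int]
shapeOf (SInput sh) = sh
shapeOf (SReplicate k d) = k : shapeOf d
shapeOf (SSum d) = drop 1 (shapeOf d)
shapeOf (SIndex m d) = drop m (shapeOf d)
shapeOf (STranspose perm d) =
  let (pre, post) = splitAt (length perm) (shapeOf d)
  in map (pre !!) perm ++ post
```

For example, shapeOf (STranspose [1, 0] (SReplicate 4 (SInput [2, 3]))) is [2, 4, 3].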