| Safe Haskell | Safe-Inferred |
|---|---|
| Language | Haskell2010 |
Synapse.Autograd
Description
This module implements reverse-mode automatic differentiation.
Machine learning and the training of models are based on calculating gradients of operations. This can be done symbolically by dynamically building a graph of all operations, which is then traversed to obtain the gradients.
Synapse provides several operations that support automatic differentiation,
but you can easily extend that list: you just need to define a function
that returns a Symbol with correct local gradients.
You can check out the implementations in the source for reference
and read more about it in the Symbol datatype docs.
Synopsis
- class (Eq a, Num a) => Symbolic a where
  - symbolicZero :: a -> a
  - symbolicOne :: a -> a
  - symbolicN :: Int -> a -> a
- newtype SymbolIdentifier = SymbolIdentifier {}
- data Symbol a = Symbol { symbolIdentifier :: SymbolIdentifier, unSymbol :: a, symbolGradients :: [(Symbol a, Symbol a -> Symbol a)] }
- type SymbolVec a = Symbol (Vec a)
- type SymbolMat a = Symbol (Mat a)
- symbol :: SymbolIdentifier -> a -> Symbol a
- constSymbol :: a -> Symbol a
- renameSymbol :: SymbolIdentifier -> Symbol a -> Symbol a
- symbolicUnaryOp :: (a -> a) -> Symbol a -> [(Symbol a, Symbol a -> Symbol a)] -> Symbol a
- symbolicBinaryOp :: (a -> a -> a) -> Symbol a -> Symbol a -> [(Symbol a, Symbol a -> Symbol a)] -> Symbol a
- data Gradients a
- allGradients :: Gradients a -> [(Symbol a, Symbol a)]
- getGradientsOf :: Symbolic a => Symbol a -> Gradients a
- wrt :: Symbolic a => Gradients a -> Symbol a -> Symbol a
- nthPartialGradient :: Symbolic a => Symbol a -> [Symbol a] -> Symbol a
- nthGradient :: Symbolic a => Int -> Symbol a -> Symbol a -> Symbol a
Symbolic and Symbol
class (Eq a, Num a) => Symbolic a where
The Symbolic typeclass describes types with a few properties that are needed for automatic differentiation.
Its methods could have default implementations based on Num, but such implementations would not always be correct:
Vecs and Mats do not have a single zero or identity element (there is one for each dimension), so a numeric literal is not enough.
For this reason, symbolicZero, symbolicOne and symbolicN additionally take a reference value to determine the dimensions.
The absence of default implementations forces instance authors to ensure the correctness of these functions manually.
Synapse provides implementations for primitive types (Int, Float, Double)
and for container types (Vec, Mat).
Detailed laws of the Symbolic properties are given in the documentation of the individual methods.
Minimal complete definition
Methods
symbolicZero :: a -> a
Returns the zero element, which is both the additive identity and the (elementwise) multiplicative zero. The argument is passed as a reference for the dimensions.
symbolicOne :: a -> a
Returns the (elementwise) multiplicative identity element. The argument is passed as a reference for the dimensions.
symbolicN :: Int -> a -> a
Returns what can be considered the constant N (the sum of N symbolicOnes). The argument is passed as a reference for the dimensions.
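To illustrate why the reference argument matters, here is a minimal sketch that re-declares the class with the signatures shown above and gives two instances: Double, where Num literals are correct, and a hypothetical list-backed ToyVec (not Synapse's Vec), where fromInteger cannot know the required dimension. ToyVec and its Num instance exist only for this illustration.

```haskell
-- Re-declared so the sketch compiles on its own; mirrors the class above.
class (Eq a, Num a) => Symbolic a where
  symbolicZero :: a -> a
  symbolicOne  :: a -> a
  symbolicN    :: Int -> a -> a

-- Scalars carry no dimension information, so Num literals are correct.
instance Symbolic Double where
  symbolicZero _ = 0
  symbolicOne  _ = 1
  symbolicN n  _ = fromIntegral n

-- Hypothetical list-backed vector (NOT Synapse's Vec).
newtype ToyVec = ToyVec [Double] deriving (Eq, Show)

instance Num ToyVec where
  ToyVec a + ToyVec b = ToyVec (zipWith (+) a b)
  ToyVec a * ToyVec b = ToyVec (zipWith (*) a b) -- elementwise product
  negate (ToyVec a)   = ToyVec (map negate a)
  abs    (ToyVec a)   = ToyVec (map abs a)
  signum (ToyVec a)   = ToyVec (map signum a)
  -- A literal has no dimension, so the best fromInteger can do is a
  -- one-element vector; zipWith then silently truncates the other operand.
  fromInteger n       = ToyVec [fromInteger n]

-- The reference value supplies the dimension of the result.
instance Symbolic ToyVec where
  symbolicZero (ToyVec xs) = ToyVec (map (const 0) xs)
  symbolicOne  (ToyVec xs) = ToyVec (map (const 1) xs)
  symbolicN n  (ToyVec xs) = ToyVec (map (const (fromIntegral n)) xs)
```

With these definitions, `ToyVec [5, 6, 7] + 1` evaluates to `ToyVec [6.0]`, because the literal becomes a one-element vector, whereas `symbolicOne (ToyVec [5, 6, 7])` gives `ToyVec [1.0,1.0,1.0]`.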
newtype SymbolIdentifier
SymbolIdentifier is a newtype that wraps a String, which must uniquely identify a symbol.
Constructors
| SymbolIdentifier | |
Fields
| |
Instances (all defined in Synapse.Autograd)
| Instance | Methods |
|---|---|
| IsString SymbolIdentifier | fromString :: String -> SymbolIdentifier |
| Monoid SymbolIdentifier | mappend :: SymbolIdentifier -> SymbolIdentifier -> SymbolIdentifier; mconcat :: [SymbolIdentifier] -> SymbolIdentifier |
| Semigroup SymbolIdentifier | (<>) :: SymbolIdentifier -> SymbolIdentifier -> SymbolIdentifier; sconcat :: NonEmpty SymbolIdentifier -> SymbolIdentifier; stimes :: Integral b => b -> SymbolIdentifier -> SymbolIdentifier |
| Show SymbolIdentifier | showsPrec :: Int -> SymbolIdentifier -> ShowS; show :: SymbolIdentifier -> String; showList :: [SymbolIdentifier] -> ShowS |
| Eq SymbolIdentifier | (==) :: SymbolIdentifier -> SymbolIdentifier -> Bool; (/=) :: SymbolIdentifier -> SymbolIdentifier -> Bool |
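Because SymbolIdentifier has IsString and Semigroup/Monoid instances, identifiers can presumably be written as string literals under OverloadedStrings and combined with (<>). The sketch below re-declares a stand-in newtype to show this pattern. It assumes that (<>) concatenates the wrapped strings, which these docs do not state, and paramName is a hypothetical helper.

```haskell
{-# LANGUAGE OverloadedStrings #-}

import Data.String (IsString (..))

-- Stand-in for the documented newtype wrapping a String.
newtype SymbolIdentifier = SymbolIdentifier String
  deriving (Eq, Show)

instance IsString SymbolIdentifier where
  fromString = SymbolIdentifier

-- Assumption of this sketch: (<>) concatenates the wrapped strings.
instance Semigroup SymbolIdentifier where
  SymbolIdentifier a <> SymbolIdentifier b = SymbolIdentifier (a ++ b)

instance Monoid SymbolIdentifier where
  mempty = SymbolIdentifier ""

-- Hypothetical naming helper: with OverloadedStrings the literals below
-- are SymbolIdentifiers, joined with (<>).
paramName :: Int -> SymbolIdentifier -> SymbolIdentifier
paramName layer suffix = "layer" <> fromString (show layer) <> "_" <> suffix
```

For example, `paramName 1 "weights"` yields the identifier `"layer1_weights"`, so distinct parameters get distinct names without manual string handling.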
data Symbol a
Datatype that represents a symbol variable (a variable whose operations are recorded so that derivatives can be obtained symbolically).
Any operation returning Symbol a, where a is Symbolic, can be differentiated automatically: the returned Symbol has a symbolGradients list,
which allows Synapse to build a graph of the computation and obtain the needed gradients.
The symbolGradients list contains pairs: the first element of each pair is a symbol with respect to which the gradient can be taken, and
the second element is a closure that represents the chain rule: it takes the incoming gradient (the gradient of the final result with respect to this Symbol) and multiplies it by the local derivative of this Symbol with respect to that symbol.
You can check out the implementations of those operations in the source for reference.
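To make the shape of symbolGradients concrete, here is a minimal sketch that re-declares Symbol with the field names from the synopsis (SymbolIdentifier simplified to a String wrapper). The helpers var, constant, mulS and propagateOnce are hypothetical, as is the convention that the result of an operation is unnamed; the point is only how each closure turns an incoming gradient into the gradient that flows to one input.

```haskell
newtype SymbolIdentifier = SymbolIdentifier String deriving (Eq, Show)

-- Stand-in for the documented record (field names as in the synopsis).
data Symbol a = Symbol
  { symbolIdentifier :: SymbolIdentifier
  , unSymbol         :: a
  , symbolGradients  :: [(Symbol a, Symbol a -> Symbol a)]
  }

-- Hypothetical helpers: a named variable and an unnamed constant.
var :: String -> a -> Symbol a
var name x = Symbol (SymbolIdentifier name) x []

constant :: a -> Symbol a
constant x = Symbol (SymbolIdentifier "") x []

-- z = x * y has local derivatives dz/dx = y and dz/dy = x, so each closure
-- multiplies the incoming gradient g by the corresponding local derivative.
mulS :: Num a => Symbol a -> Symbol a -> Symbol a
mulS x y = Symbol (SymbolIdentifier "") (unSymbol x * unSymbol y)
  [ (x, \g -> mulS g y)
  , (y, \g -> mulS g x)
  ]

-- Applies every chain-rule closure of s to an upstream gradient and
-- reports the value that flows to each direct input.
propagateOnce :: Symbol a -> Symbol a -> [(SymbolIdentifier, a)]
propagateOnce upstream s =
  [ (symbolIdentifier input, unSymbol (chain upstream))
  | (input, chain) <- symbolGradients s
  ]
```

For `z = mulS (var "x" 3) (var "y" 4)`, `propagateOnce (constant 1) z` gives 4 for x and 3 for y; seeding with 2 instead doubles both, which is the chain rule scaling the local derivatives by the upstream gradient.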
Constructors
| Symbol | |
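Fields
| Field | Type |
|---|---|
| symbolIdentifier | SymbolIdentifier |
| unSymbol | a |
| symbolGradients | [(Symbol a, Symbol a -> Symbol a)] |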
Instances
symbol :: SymbolIdentifier -> a -> Symbol a
Creates a new symbol that refers to a variable (so it must have a name in order to be differentiated with respect to).
constSymbol :: a -> Symbol a
Creates a new symbol that refers to a constant (so it has no name and its gradients are not saved).
renameSymbol :: SymbolIdentifier -> Symbol a -> Symbol a
Renames a symbol, which allows differentiating with respect to it. Note: renaming effectively creates a new symbol for the purposes of gradient calculation.
symbolicUnaryOp :: (a -> a) -> Symbol a -> [(Symbol a, Symbol a -> Symbol a)] -> Symbol a
Converts a unary operation into a symbolic one.
symbolicBinaryOp :: (a -> a -> a) -> Symbol a -> Symbol a -> [(Symbol a, Symbol a -> Symbol a)] -> Symbol a
Converts a binary operation into a symbolic one.
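A sketch of how symbolicUnaryOp and symbolicBinaryOp might be used to define new operations, as the module description suggests. It assumes, without confirmation from these docs, that the helpers apply the plain function to the wrapped value(s), leave the result unnamed, and attach the supplied list as the result's symbolGradients. Symbol is re-declared as a stand-in, and mulS and expS are hypothetical user-defined operations.

```haskell
newtype SymbolIdentifier = SymbolIdentifier String deriving (Eq, Show)

-- Stand-in for the documented record (field names as in the synopsis).
data Symbol a = Symbol
  { symbolIdentifier :: SymbolIdentifier
  , unSymbol         :: a
  , symbolGradients  :: [(Symbol a, Symbol a -> Symbol a)]
  }

-- Assumed semantics: apply the function to the wrapped value(s),
-- leave the result unnamed, and attach the supplied gradient list.
symbolicUnaryOp :: (a -> a) -> Symbol a -> [(Symbol a, Symbol a -> Symbol a)] -> Symbol a
symbolicUnaryOp f x grads = Symbol (SymbolIdentifier "") (f (unSymbol x)) grads

symbolicBinaryOp :: (a -> a -> a) -> Symbol a -> Symbol a -> [(Symbol a, Symbol a -> Symbol a)] -> Symbol a
symbolicBinaryOp f x y grads = Symbol (SymbolIdentifier "") (f (unSymbol x) (unSymbol y)) grads

-- Hypothetical operation: z = x * y sends g * y to x and g * x to y.
mulS :: Num a => Symbol a -> Symbol a -> Symbol a
mulS x y = symbolicBinaryOp (*) x y [(x, \g -> mulS g y), (y, \g -> mulS g x)]

-- Hypothetical operation: d/dx (exp x) = exp x, so the incoming gradient
-- is multiplied by expS x itself.
expS :: Floating a => Symbol a -> Symbol a
expS x = symbolicUnaryOp exp x [(x, \g -> mulS g (expS x))]
```

Since the closures return Symbols built from mulS and expS, the gradients they produce carry their own gradient lists and could in turn be differentiated again.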
Gradients calculation
data Gradients a
The Gradients datatype holds all gradients of one symbol with respect to other symbols.
allGradients :: Gradients a -> [(Symbol a, Symbol a)]
Returns all gradients of the symbol as key-value pairs.
getGradientsOf :: Symbolic a => Symbol a -> Gradients a
Generates Gradients for the given symbol.
wrt :: Symbolic a => Gradients a -> Symbol a -> Symbol a
Chooses the gradient with respect to the given symbol.
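A naive sketch of what getGradientsOf and wrt compute, assuming gradients are keyed by identifier and that unnamed (constant) symbols are skipped. Gradients here is a hypothetical Map-based stand-in, not the library's type. The output is seeded with the literal 1, which is only valid for scalars; for Vec or Mat the seed must match the output's dimensions, which is what symbolicOne provides. The recursion follows every path through the graph and sums the contributions, which is correct but can repeat work on shared subexpressions; a more efficient traversal would visit each node once in reverse topological order.

```haskell
import qualified Data.Map.Strict as Map

newtype SymbolIdentifier = SymbolIdentifier String deriving (Eq, Ord, Show)

-- Stand-in for the documented record (field names as in the synopsis).
data Symbol a = Symbol
  { symbolIdentifier :: SymbolIdentifier
  , unSymbol         :: a
  , symbolGradients  :: [(Symbol a, Symbol a -> Symbol a)]
  }

symbol :: SymbolIdentifier -> a -> Symbol a
symbol name x = Symbol name x []

constSymbol :: a -> Symbol a
constSymbol x = Symbol (SymbolIdentifier "") x []

-- Results of operations are unnamed (an assumption of this sketch).
addS, mulS :: Num a => Symbol a -> Symbol a -> Symbol a
addS x y = Symbol (SymbolIdentifier "") (unSymbol x + unSymbol y) [(x, id), (y, id)]
mulS x y = Symbol (SymbolIdentifier "") (unSymbol x * unSymbol y)
  [(x, \g -> mulS g y), (y, \g -> mulS g x)]

-- Hypothetical stand-in: identifier -> accumulated gradient.
newtype Gradients a = Gradients (Map.Map SymbolIdentifier (Symbol a))

getGradientsOf :: Num a => Symbol a -> Gradients a
getGradientsOf out = Gradients (go (constSymbol 1) out Map.empty)
  where
    -- g is the gradient of `out` w.r.t. s along the current path; paths
    -- reaching the same named symbol are summed with addS.
    go g s acc = foldr (\(input, chain) m -> go (chain g) input m) acc' (symbolGradients s)
      where
        acc'
          | symbolIdentifier s == SymbolIdentifier "" = acc
          | otherwise = Map.insertWith addS (symbolIdentifier s) g acc

-- Symbols that do not appear in the graph get a zero gradient.
wrt :: Num a => Gradients a -> Symbol a -> Symbol a
wrt (Gradients m) x = Map.findWithDefault (constSymbol 0) (symbolIdentifier x) m
```

For `z = addS (mulS x y) x` with x = 3 and y = 4, `getGradientsOf z` gives dz/dx = y + 1 = 5 (the two paths through x are summed) and dz/dy = x = 3.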