{-# LANGUAGE TypeFamilies #-}
module Synapse.NN.Optimizers
(
Optimizer (OptimizerParameters, optimizerInitialParameters, optimizerUpdateStep)
, optimizerUpdateParameters
, SGD (SGD, sgdMomentum, sgdNesterov)
) where
import Synapse.Tensors (DType, ElementwiseScalarOps((*.)))
import Synapse.Tensors.Mat (Mat)
import qualified Synapse.Tensors.Mat as M
import Synapse.Autograd (Symbolic, Symbol(unSymbol), SymbolMat, Gradients, wrt)
import Synapse.NN.Layers.Initializers (zeroes)
import Data.Kind (Type)
-- | Common interface of all optimizers.
--
-- An optimizer carries its own per-parameter state ('OptimizerParameters').
-- It creates the initial state for a parameter with 'optimizerInitialParameters'.
-- It advances a parameter together with that state using 'optimizerUpdateStep'.
class Optimizer optimizer where
    -- | State the optimizer keeps for a single parameter between update steps.
    type OptimizerParameters optimizer a :: Type

    -- | Creates the starting optimizer state for the given parameter.
    optimizerInitialParameters
        :: Num a
        => optimizer a                          -- ^ optimizer configuration
        -> Mat a                                -- ^ parameter the state is created for
        -> OptimizerParameters optimizer a

    -- | Performs a single update of one parameter and its optimizer state.
    optimizerUpdateStep
        :: Num a
        => optimizer a                                -- ^ optimizer configuration
        -> (a, Mat a)                                 -- ^ learning rate and gradient for this parameter
        -> (Mat a, OptimizerParameters optimizer a)   -- ^ current parameter and its state
        -> (Mat a, OptimizerParameters optimizer a)   -- ^ updated parameter and its new state
-- | Applies 'optimizerUpdateStep' to every parameter in a list.
--
-- Each parameter comes paired with its optimizer state. Its gradient is looked
-- up in the shared 'Gradients' with 'wrt'. Both the gradient and the parameter
-- are unwrapped with 'unSymbol' before being passed to the optimizer. All
-- parameters use the same learning rate. The output keeps the length and order
-- of the input list.
optimizerUpdateParameters
    :: (Symbolic a, Optimizer optimizer)
    => optimizer a                                      -- ^ optimizer configuration
    -> (a, Gradients (Mat a))                           -- ^ learning rate and computed gradients
    -> [(SymbolMat a, OptimizerParameters optimizer a)] -- ^ parameters with their optimizer state
    -> [(Mat a, OptimizerParameters optimizer a)]       -- ^ updated parameters with their new state
optimizerUpdateParameters optimizer (lrValue, gradients) = map updateOne
  where
    -- Updates a single (parameter, state) pair.
    updateOne (parameter, optimizerParameter) =
        let gradient = unSymbol (gradients `wrt` parameter)
        in  optimizerUpdateStep optimizer (lrValue, gradient) (unSymbol parameter, optimizerParameter)
-- | Configuration of the stochastic gradient descent optimizer.
data SGD a = SGD
    { sgdMomentum :: a     -- ^ factor the velocity is multiplied by on every update
    , sgdNesterov :: Bool  -- ^ whether the Nesterov variant of the update is used
    }
    deriving (Eq, Show)
-- | The element type of an 'SGD' configuration is the type of its momentum.
type instance DType (SGD e) = e
-- | Stochastic gradient descent with momentum and an optional Nesterov variant.
--
-- The per-parameter state is a velocity matrix. It starts zero-filled with the
-- same size as the parameter. One update step computes:
--
-- > velocity'  = velocity * momentum - gradient * lr
-- > parameter' = parameter + velocity'                             -- plain momentum
-- > parameter' = parameter + velocity' * momentum - gradient * lr  -- Nesterov
instance Optimizer SGD where
    type OptimizerParameters SGD a = Mat a

    optimizerInitialParameters _ = zeroes . M.size

    optimizerUpdateStep (SGD mu useNesterov) (lr, gradient) (parameter, velocity) =
        (updatedParameter, updatedVelocity)
      where
        -- The gradient scaled by the learning rate appears in both update rules.
        scaledGradient = gradient *. lr

        updatedVelocity = (velocity *. mu) - scaledGradient

        updatedParameter
            | useNesterov = parameter + (updatedVelocity *. mu) - scaledGradient
            | otherwise   = parameter + updatedVelocity