{- | This module implements several optimizers that are used in training.
-}


-- 'TypeFamilies' is needed to use the 'DType' type family and to declare the associated type in the 'Optimizer' typeclass.

{-# LANGUAGE TypeFamilies #-}


module Synapse.NN.Optimizers
    ( -- * 'Optimizer' typeclass

      
      Optimizer (OptimizerParameters, optimizerInitialParameters, optimizerUpdateStep)

    , optimizerUpdateParameters
    
      -- * Optimizers


    , SGD (SGD, sgdMomentum, sgdNesterov)
    ) where


import Synapse.Tensors (DType, ElementwiseScalarOps((*.)))

import Synapse.Tensors.Mat (Mat)
import qualified Synapse.Tensors.Mat as M

import Synapse.Autograd (Symbolic, Symbol(unSymbol), SymbolMat, Gradients, wrt)

import Synapse.NN.Layers.Initializers (zeroes)

import Data.Kind (Type)


-- | 'Optimizer' typeclass represents an optimizer - an algorithm that defines the update rule for neural network parameters.

class Optimizer optimizer where
    -- | 'OptimizerParameters' represents the optimizer-specific state (e.g. velocity matrices) that the optimizer needs to implement its update rule.

    type OptimizerParameters optimizer a :: Type

    -- | Returns the initial state of optimizer-specific parameters for a given parameter matrix.

    optimizerInitialParameters :: Num a => optimizer a -> Mat a -> OptimizerParameters optimizer a
    
    -- | Performs one update step of the optimizer on a single parameter.

    optimizerUpdateStep
        :: Num a
        => optimizer a                               -- ^ Optimizer itself.

        -> (a, Mat a)                                -- ^ Learning rate and gradient of given parameter.

        -> (Mat a, OptimizerParameters optimizer a)  -- ^ Given parameter and current state of optimizer-specific parameters.

        -> (Mat a, OptimizerParameters optimizer a)  -- ^ Updated parameter and a new state of optimizer-specific parameters.


-- | 'optimizerUpdateParameters' function updates the whole model using an optimizer by performing 'optimizerUpdateStep' for every parameter.
optimizerUpdateParameters
    :: (Symbolic a, Optimizer optimizer)
    => optimizer a                                       -- ^ Optimizer itself.
    -> (a, Gradients (Mat a))                            -- ^ Learning rate and gradients of all parameters.
    -> [(SymbolMat a, OptimizerParameters optimizer a)]  -- ^ Given parameters and current state of optimizer-specific parameters.
    -> [(Mat a, OptimizerParameters optimizer a)]        -- ^ Updated parameters and a new state of optimizer-specific parameters.
optimizerUpdateParameters _ _ [] = []
optimizerUpdateParameters optimizer (lrValue, gradients) ((parameter, optimizerParameter):xs) =
    -- Look up this parameter's gradient in the gradients collection ('wrt'),
    -- strip the symbolic wrapper with 'unSymbol', and let the optimizer produce
    -- the updated parameter matrix together with its new optimizer-specific state.
    optimizerUpdateStep optimizer (lrValue, unSymbol $ gradients `wrt` parameter) (unSymbol parameter, optimizerParameter)
        : optimizerUpdateParameters optimizer (lrValue, gradients) xs


-- | 'SGD' is an optimizer that implements the stochastic gradient descent algorithm with optional (Nesterov) momentum.
data SGD a = SGD
    { sgdMomentum :: a     -- ^ Momentum coefficient.
    , sgdNesterov :: Bool  -- ^ Nesterov update rule.
    } deriving (Eq, Show)

-- | The scalar type carried by an 'SGD' value is the scalar type of the parameters it optimizes.
type instance DType (SGD a) = a

instance Optimizer SGD where
    -- SGD only needs to track one velocity matrix per parameter.
    type OptimizerParameters SGD a = Mat a

    -- Velocity starts at zero, with the same dimensions as the parameter it accompanies.
    optimizerInitialParameters _ parameter = zeroes (M.size parameter)

    optimizerUpdateStep (SGD momentum nesterov) (lr, gradient) (parameter, velocity) = (parameter', velocity')
      where
        -- Classic momentum update: decay the previous velocity and subtract the scaled gradient.
        velocity' = velocity *. momentum - gradient *. lr

        -- Nesterov variant additionally looks ahead along the momentum-scaled new
        -- velocity and applies the gradient step directly to the parameter;
        -- otherwise the parameter simply follows the new velocity.
        parameter' = if nesterov
                     then parameter + velocity' *. momentum - gradient *. lr
                     else parameter + velocity'