{- | Provides a collection of regularizers: functions that impose penalties on parameters by adding their result to the loss value.
-}


module Synapse.NN.Layers.Regularizers
    ( -- * 'RegularizerFn' type alias and 'Regularizer' newtype


      RegularizerFn
    , Regularizer (Regularizer, unRegularizer)

      -- * Regularizers


    , l1
    , l2
    ) where


import Synapse.Tensors (ElementwiseScalarOps((*.)), SingletonOps(elementsSum))

import Synapse.Autograd (SymbolMat, Symbolic)


-- | 'RegularizerFn' type alias represents functions that impose penalties on parameters;
-- the result of regularization is added to the loss value.
type RegularizerFn a = SymbolMat a -> SymbolMat a


{- | 'Regularizer' newtype wraps 'RegularizerFn's - functions that impose penalties on parameters.

Every regularization function must return the symbol of a singleton matrix.
-}
newtype Regularizer a = Regularizer
    { unRegularizer :: RegularizerFn a  -- ^ Unwraps 'Regularizer' newtype.
    }


-- Regularizers


-- | Applies an L1 regularization penalty: the sum of absolute values of the parameter,
-- multiplied by a coefficient.
--
-- The absolute values of @mat@ are reduced with 'elementsSum' and the result is then
-- scaled by the coefficient using '*.'.
l1
    :: (Symbolic a, Num a)
    => a                -- ^ Coefficient @k@ that scales the penalty.
    -> RegularizerFn a  -- ^ Regularizer taking the parameter symbol @mat@ and returning the penalty.
l1 k mat = elementsSum (abs mat) *. k

-- | Applies an L2 regularization penalty: the sum of squared values of the parameter,
-- multiplied by a coefficient.
--
-- @mat@ is multiplied by itself, reduced with 'elementsSum', and the result is then
-- scaled by the coefficient using '*.'.
l2
    :: (Symbolic a, Num a)
    => a                -- ^ Coefficient @k@ that scales the penalty.
    -> RegularizerFn a  -- ^ Regularizer taking the parameter symbol @mat@ and returning the penalty.
l2 k mat = elementsSum (mat * mat) *. k