{- | Provides dense layer implementation.

'Dense' datatype represents a densely-connected neural network layer;
it performs the following operation: @x `matMul` w + b@, where @w@ is the weight matrix and @b@ is the bias vector of the layer.
-}


-- 'TypeFamilies' are needed to instantiate 'DType'.

{-# LANGUAGE TypeFamilies #-}


module Synapse.NN.Layers.Dense
    ( -- * 'Dense' datatype


      Dense (Dense, denseWeights, denseBias, denseConstraints, denseRegularizers)
    , layerDenseWith
    , layerDense
    ) where



import Synapse.Tensors (DType, SingletonOps(singleton), MatOps(addMatRow, matMul))

import Synapse.Tensors.Vec (Vec)

import Synapse.Tensors.Mat (Mat)
import qualified Synapse.Tensors.Mat as M

import Synapse.Autograd (Symbolic, SymbolIdentifier(SymbolIdentifier), symbol, SymbolMat)

import Synapse.NN.Layers.Layer (AbstractLayer(..), LayerConfiguration)
import Synapse.NN.Layers.Initializers (Initializer(Initializer), zeroes, ones)
import Synapse.NN.Layers.Constraints (Constraint(Constraint))
import Synapse.NN.Layers.Regularizers (Regularizer(Regularizer))


{- | 'Dense' datatype represents a densely-connected neural network layer.

'Dense' performs the following operation: @x `matMul` w + b@, where @w@ is the weight matrix and @b@ is the bias vector of the layer.
Every 'Dense' layer stores a bias vector; the bias is not optional.
-}
data Dense a = Dense
    { denseWeights      :: Mat a                           -- ^ Weight matrix; its row count is the layer's input size and its column count the output size.
    , denseBias         :: Vec a                           -- ^ Bias vector of the layer.
    , denseConstraints  :: (Constraint a, Constraint a)    -- ^ Constraints on the weights and the bias (in that order), applied on every parameter update.
    , denseRegularizers :: (Regularizer a, Regularizer a)  -- ^ Regularizers on the weights and the bias (in that order); their penalties are summed.
    }

-- | Wraps the weight matrix of a dense layer into an autograd symbol.
--
-- The symbol's identifier is the layer prefix followed by @"1"@.
weightsSymbol :: SymbolIdentifier -> Mat a -> SymbolMat a
weightsSymbol prefix = symbol weightsId
  where
    weightsId = prefix <> SymbolIdentifier "1"


-- | Wraps the bias vector of a dense layer into an autograd symbol.
--
-- The vector is first turned into a matrix with 'M.rowVec'. The symbol's
-- identifier is the layer prefix followed by @"2"@.
biasSymbol :: SymbolIdentifier -> Vec a -> SymbolMat a
biasSymbol prefix = symbol biasId . M.rowVec
  where
    biasId = prefix <> SymbolIdentifier "2"

-- 'DType' of a 'Dense' layer is its type parameter @a@, the same @a@ that
-- parameterizes its 'Mat' weights and 'Vec' bias.
type instance DType (Dense a) = a

-- | 'Dense' exposes exactly two trainable parameters, always ordered as
-- weights first, then bias. The bias is passed to autograd as the matrix
-- produced by 'M.rowVec'.
instance AbstractLayer Dense where
    -- Rows of the weight matrix correspond to inputs, columns to outputs.
    inputSize layer = Just (M.nRows (denseWeights layer))
    outputSize layer = Just (M.nCols (denseWeights layer))

    nParameters _ = 2

    getParameters prefix Dense { denseWeights = w, denseBias = b } =
        [weightsSymbol prefix w, biasSymbol prefix b]

    -- New values are expected in the same order 'getParameters' produces them.
    -- The layer's constraints are applied before the values are stored, and
    -- row 0 of the constrained bias matrix becomes the new bias vector.
    updateParameters
        layer@(Dense { denseConstraints = (Constraint constrainWeights, Constraint constrainBias) })
        [newWeights, newBiasMat] =
            layer { denseWeights = constrainWeights newWeights
                  , denseBias    = M.indexRow (constrainBias newBiasMat) 0
                  }
    updateParameters _ _ =
        error "Parameters update failed - wrong amount of parameters was given"

    -- First component: @(input `matMul` weights) `addMatRow` bias@.
    -- Second component: sum of the weights' and the bias' regularization terms.
    symbolicForward prefix input
        Dense { denseWeights      = w
              , denseBias         = b
              , denseRegularizers = (Regularizer regularizeWeights, Regularizer regularizeBias)
              } =
        (output, penalty)
      where
        weightsSym = weightsSymbol prefix w
        biasSym    = biasSymbol prefix b
        output     = addMatRow (matMul input weightsSym) biasSym
        penalty    = regularizeWeights weightsSym + regularizeBias biasSym

-- | Creates the configuration of a dense layer with explicitly chosen
-- initializers, constraints and regularizers.
--
-- The resulting configuration is applied to the input size. It initializes
-- the weights as an @(input, neurons)@ matrix, and takes the bias from row 0
-- of a @(1, neurons)@ matrix produced by the bias initializer.
layerDenseWith
    :: Symbolic a
    => (Initializer a, Constraint a, Regularizer a)  -- ^ Weights initializer, constraint and regularizer.
    -> (Initializer a, Constraint a, Regularizer a)  -- ^ Bias initializer, constraint and regularizer.
    -> Int                                           -- ^ Amount of neurons (the layer's output size).
    -> LayerConfiguration (Dense a)
layerDenseWith (Initializer initWeights, constrainW, regularizeW)
               (Initializer initBias, constrainB, regularizeB)
               nNeurons nInputs =
    Dense { denseWeights      = initWeights (nInputs, nNeurons)
          , denseBias         = M.indexRow (initBias (1, nNeurons)) 0
          , denseConstraints  = (constrainW, constrainB)
          , denseRegularizers = (regularizeW, regularizeB)
          }

-- | Creates the default configuration of a dense layer:
--
-- * weights are initialized with ones and bias with zeroes;
-- * both constraints are 'id' (no constraint);
-- * both regularizers ignore their input and yield a zero penalty (no regularization).
--
-- NOTE(review): if 'ones' fills the whole weight matrix with the same value,
-- every neuron starts with identical weights. That usually keeps neurons from
-- learning different features. Consider 'layerDenseWith' with another weights
-- initializer.
layerDense :: Symbolic a => Int -> LayerConfiguration (Dense a)
layerDense = layerDenseWith (Initializer ones, noConstraint, noRegularizer)
                            (Initializer zeroes, noConstraint, noRegularizer)
  where
    -- Leaves the parameter unchanged.
    noConstraint = Constraint id
    -- Always returns a zero penalty, whatever the parameter is.
    noRegularizer = Regularizer (const $ singleton 0)