{- | Allows constraining the values of layer parameters.

The 'ConstraintFn' type alias represents functions that are able to constrain the values of a matrix,
and the 'Constraint' newtype wraps 'ConstraintFn's.

'ConstraintFn's should be applied to matrices in the 'Synapse.NN.Layers.Layer.updateParameters' function.
-}


module Synapse.NN.Layers.Constraints
    ( -- * 'ConstraintFn' type alias and 'Constraint' newtype

      
      ConstraintFn

    , Constraint (Constraint, unConstraint)

      -- * Value constraints

    
    , nonNegative
    , clampMin
    , clampMax
    , clampMinMax

      -- * Matrix-specific constraints

    
    , centralize
    ) where


import Synapse.Tensors (ElementwiseScalarOps((+.), (-.)), SingletonOps(unSingleton, mean))

import Synapse.Tensors.Mat (Mat)

import Data.Ord (clamp)


-- | Type of functions that constrain the values of a matrix.
--
-- A 'ConstraintFn' receives a matrix and returns a matrix with the same
-- element type whose values have been adjusted to satisfy some constraint.

type ConstraintFn e = Mat e -> Mat e


-- | 'Constraint' newtype wraps a 'ConstraintFn' - a function that is able to constrain the values of a matrix.

newtype Constraint a = Constraint
    { unConstraint :: ConstraintFn a  -- ^ Unwraps 'Constraint' newtype.
    }


-- Value constraints


-- | Ensures that matrix values are non-negative.
--
-- Every element @x@ is replaced by @max 0 x@, so negative values become @0@
-- and non-negative values are left unchanged.

nonNegative :: (Num a, Ord a) => ConstraintFn a
nonNegative = fmap (max 0)

-- | Ensures that matrix values are greater than or equal to the given value.
--
-- Every element @x@ is replaced by @max minimal x@, so values below
-- @minimal@ are raised to it and all other values are left unchanged.

clampMin :: Ord a => a -> ConstraintFn a
clampMin minimal = fmap (max minimal)

-- | Ensures that matrix values are less than or equal to the given value.
--
-- Every element @x@ is replaced by @min maximal x@, so values above
-- @maximal@ are lowered to it and all other values are left unchanged.

clampMax :: Ord a => a -> ConstraintFn a
clampMax maximal = fmap (min maximal)

-- | Ensures that matrix values lie within the given inclusive bounds @(low, high)@.
--
-- Every element is passed through 'clamp' from "Data.Ord". Values below @low@
-- become @low@, values above @high@ become @high@, and values in between are
-- left unchanged. The bounds are expected to satisfy @low <= high@.
-- 'clamp' computes @min high (max x low)@, so with inverted bounds every
-- element becomes @high@.

clampMinMax :: Ord a => (a, a) -> ConstraintFn a
clampMinMax bounds = fmap (clamp bounds)


-- Matrix-specific constraints


-- | Ensures that matrix values are centralized by mean around the given value.
--
-- The mean of the matrix (obtained with 'mean' and extracted with
-- 'unSingleton') is subtracted from every element, and then @center@ is added
-- to every element. Assuming 'mean' is the arithmetic mean of all elements,
-- the resulting matrix has mean @center@ (up to floating-point rounding).

centralize :: Fractional a => a -> ConstraintFn a
centralize center mat = (mat -. currentMean) +. center
  where
    -- Scalar mean of the input matrix, computed once.
    currentMean = unSingleton (mean mat)