module Synapse.NN.Layers.Constraints
(
ConstraintFn
, Constraint (Constraint, unConstraint)
, nonNegative
, clampMin
, clampMax
, clampMinMax
, centralize
) where
import Synapse.Tensors (ElementwiseScalarOps((+.), (-.)), SingletonOps(unSingleton, mean))
import Synapse.Tensors.Mat (Mat)
import Data.Ord (clamp)
-- | A constraint function: a transformation from a 'Mat' to a 'Mat' of the
-- same element type. (Presumably applied to layer parameters, given this
-- module's name; confirm against the layer code that consumes it.)
type ConstraintFn a = Mat a -> Mat a
-- | Newtype wrapper around a 'ConstraintFn'.
--
-- Use the 'Constraint' constructor to wrap a constraint function and
-- 'unConstraint' to get the plain function back.
newtype Constraint a = Constraint
    { unConstraint :: ConstraintFn a  -- ^ The wrapped constraint function.
    }
-- | Element-wise lower bound of zero.
--
-- Every negative element of the matrix is replaced by @0@; elements that are
-- already non-negative are left unchanged (@max 0@ applied via 'fmap').
nonNegative :: (Num a, Ord a) => ConstraintFn a
nonNegative = fmap (max 0)
-- | Element-wise lower bound.
--
-- Every element smaller than @minimal@ is raised to @minimal@; all other
-- elements are left unchanged (@max minimal@ applied via 'fmap').
clampMin
    :: Ord a
    => a               -- ^ @minimal@: the smallest value allowed in the result.
    -> ConstraintFn a
clampMin minimal = fmap (max minimal)
-- | Element-wise upper bound.
--
-- Every element larger than @maximal@ is lowered to @maximal@; all other
-- elements are left unchanged (@min maximal@ applied via 'fmap').
clampMax
    :: Ord a
    => a               -- ^ @maximal@: the largest value allowed in the result.
    -> ConstraintFn a
clampMax maximal = fmap (min maximal)
-- | Element-wise clamp into the closed range @[low, high]@.
--
-- Each element is passed through 'Data.Ord.clamp'. That function is defined
-- as @clamp (low, high) x = min high (max x low)@, so the bounds are not
-- reordered: if @low > high@, every element becomes @high@.
clampMinMax
    :: Ord a
    => (a, a)          -- ^ @(low, high)@ bounds, expected with @low <= high@.
    -> ConstraintFn a
clampMinMax bounds = fmap (clamp bounds)
-- | Shifts a matrix so that its mean becomes @center@.
--
-- The mean of @mat@ is computed with 'mean' and extracted as a scalar with
-- 'unSingleton'. That scalar is subtracted from every element ('-.'), and
-- @center@ is then added to every element ('+.'). The resulting mean is
-- therefore @center@.
--
-- NOTE(review): behavior on an empty matrix depends on how 'mean' handles it
-- (likely a NaN from division by zero); this function does not guard it.
centralize
    :: Fractional a
    => a               -- ^ @center@: the desired mean of the result.
    -> ConstraintFn a
centralize center mat =
    -- Explicit grouping: subtract the current mean first, then add the
    -- target center. Only this left-associated reading typechecks, since
    -- both operators take a matrix on the left and a scalar on the right.
    (mat -. unSingleton (mean mat)) +. center