{- | Provides a collection of loss functions, which serve as a reference for what needs to be minimised during training.

The 'LossFn' type alias represents those functions, and "Synapse" offers a variety of them.
-}


module Synapse.NN.Losses
    ( -- * 'LossFn' type alias and 'Loss' newtype


      LossFn
    
    , Loss (Loss, unLoss)

      -- * Regression losses

    
    , mse
    , msle
    , mae
    , mape
    , logcosh
    ) where


import Synapse.Tensors (ElementwiseScalarOps((+.), (*.), (**.)), SingletonOps(mean))

import Synapse.Autograd (SymbolMat, Symbolic)


-- | 'LossFn' is the shape shared by every loss in this module: it takes two
-- symbolic matrices and produces a symbolic matrix that describes the
-- relation between them which training should minimise.
--
-- The losses defined in this module name the first argument @true@ (expected
-- values) and the second @predicted@.
type LossFn a
    =  SymbolMat a  -- expected (true) values
    -> SymbolMat a  -- predicted values
    -> SymbolMat a  -- loss symbol to minimise


{- | 'Loss' newtype wraps 'LossFn's - differentiable functions that provide a reference of what relation between matrices needs to be minimised.

Every loss function must return a symbol of a singleton matrix.
-}
newtype Loss a = Loss
    { unLoss :: LossFn a  -- ^ Unwraps 'Loss' newtype.
    }


-- Regression losses


-- | Computes the mean squared error (MSE):
-- @mean ((true - predicted) ** 2)@.
--
-- The elementwise squared differences are reduced with 'mean'.
mse :: (Symbolic a, Floating a) => LossFn a
mse true predicted = mean $ (true - predicted) **. 2.0

-- | Computes the mean squared logarithmic error (MSLE):
-- @mean ((log (true + 1) - log (predicted + 1)) ** 2)@.
--
-- Every element is shifted by 1 before taking the logarithm, so zero-valued
-- entries stay finite. Entries at or below -1 fall outside the domain of 'log'.
msle :: (Symbolic a, Floating a) => LossFn a
msle true predicted = mean $ (log (true +. 1) - log (predicted +. 1)) **. 2.0

-- | Computes the mean absolute error (MAE):
-- @mean (abs (true - predicted))@.
mae :: (Symbolic a, Floating a) => LossFn a
mae true predicted = mean $ abs (true - predicted)

-- | Computes the mean absolute percentage error (MAPE):
-- @100 * mean (abs ((true - predicted) / true))@.
--
-- The absolute value is taken of the whole relative error, i.e. the difference
-- is divided by @|true|@ rather than by the signed @true@. Otherwise negative
-- expected values would produce negative "percentages" that cancel out positive
-- ones. Expected values of zero still divide by zero, so inputs containing
-- them yield non-finite results.
mape :: (Symbolic a, Floating a) => LossFn a
mape true predicted = mean (abs ((true - predicted) / true)) *. 100

-- | Computes the mean of the logarithm of the hyperbolic cosine of the error:
-- @mean (log (cosh (true - predicted)))@.
--
-- Evaluating @cosh@ directly overflows for large errors. With @e = true - predicted@,
-- it overflows for @|e|@ above roughly 89 for 'Float' and 710 for 'Double',
-- which turns the loss into Infinity and its gradient into NaN.
-- This uses the mathematically equivalent form
--
-- @log (cosh e) = |e| + log (1 + exp (-2 * |e|)) - log 2@
--
-- in which the argument of 'exp' is never positive and so cannot overflow.
logcosh :: (Symbolic a, Floating a) => LossFn a
logcosh true predicted = mean $ stable +. negate (log 2)
  where
    -- cosh is even, so only the magnitude of the error matters.
    absErr = abs (true - predicted)
    -- |e| + log (1 + exp (-2|e|)); exp's argument is <= 0, so it stays in (0, 1].
    stable = absErr + log (exp (absErr *. (-2)) +. 1)