{- | Initializers for the values of layer parameters.

The 'InitializerFn' type alias represents functions that initialize a matrix of a given size,
and the 'Initializer' newtype wraps such 'InitializerFn's.

"Synapse" provides four kinds of initializers:

* Non-random constant initializers
* Random uniform distribution initializers
* Random normal distribution initializers
* Matrix-specific initializers
-}


module Synapse.NN.Layers.Initializers
    ( -- * 'InitializerFn' type alias and 'Initializer' newtype


      InitializerFn

    , Initializer (Initializer, unInitializer)
        
      -- * Non-random constant initializers


    , constants
    , zeroes
    , ones

      -- * Random uniform distribution initializers


    , randomUniform
    , lecunUniform
    , heUniform
    , glorotUniform

      -- * Random normal distribution initializers

    
    , randomNormal
    , lecunNormal
    , heNormal
    , glorotNormal

      -- * Matrix-like initializers

    
    , identity
    , orthogonal
    ) where


import Synapse.Tensors.Mat (Mat)
import qualified Synapse.Tensors.Mat as M

import System.Random (uniformListR, uniformRs, UniformRange, RandomGen)


-- | A function that builds a matrix whose size is given as an @(input, output)@ pair.
type InitializerFn a = (Int, Int) -> Mat a

-- | Newtype wrapper around an 'InitializerFn', i.e. a function that builds a matrix of a given size.
newtype Initializer a = Initializer
    { -- | Unwraps the 'Initializer' newtype.
      unInitializer :: InitializerFn a
    }


-- Non-random constant initializers


-- | Builds a matrix of the given size in which every entry equals the given constant.
constants :: Num a => a -> InitializerFn a
constants value dims@(_, _) = M.replicate dims value

-- | Builds a matrix of the given size filled with zeroes (a 'constants' matrix of @0@).
zeroes :: Num a => InitializerFn a
zeroes dims = constants 0 dims

-- | Builds a matrix of the given size filled with ones (a 'constants' matrix of @1@).
ones :: Num a => InitializerFn a
ones dims = constants 1 dims


-- Random uniform distribution initializers


{- | Builds a matrix of the given size from @input * output@ samples drawn
with 'uniformListR' from the given range.

The generator is consumed but not returned, so split it before calling this
function if it is needed afterwards.
-}
randomUniform :: (UniformRange a, RandomGen g) => (a, a) -> g -> InitializerFn a
randomUniform range gen sizes@(input, output) = M.fromList sizes samples
  where
    -- The updated generator is intentionally discarded.
    (samples, _) = uniformListR (input * output) range gen

{- | Builds a matrix of the given size from LeCun uniform samples: uniform samples
over the range @(-limit, limit)@ where @limit = sqrt (3 / input)@.

The generator is consumed but not returned, so split it before calling this
function if it is needed afterwards.
-}
lecunUniform :: (UniformRange a, Floating a, RandomGen g) => g -> InitializerFn a
lecunUniform gen sizes@(input, _) = randomUniform (negate limit, limit) gen sizes
  where
    limit = sqrt (3.0 / fromIntegral input)

{- | Builds a matrix of the given size from He uniform samples: uniform samples
over the range @(-limit, limit)@ where @limit = sqrt (6 / input)@.

The generator is consumed but not returned, so split it before calling this
function if it is needed afterwards.
-}
heUniform :: (UniformRange a, Floating a, RandomGen g) => g -> InitializerFn a
heUniform gen sizes@(input, _) = randomUniform (negate limit, limit) gen sizes
  where
    limit = sqrt (6.0 / fromIntegral input)

{- | Builds a matrix of the given size from Glorot uniform samples: uniform samples
over the range @(-limit, limit)@ where @limit = sqrt (6 / (input + output))@.

The generator is consumed but not returned, so split it before calling this
function if it is needed afterwards.
-}
glorotUniform :: (UniformRange a, Floating a, RandomGen g) => g -> InitializerFn a
glorotUniform gen sizes@(input, output) = randomUniform (negate limit, limit) gen sizes
  where
    limit = sqrt (6.0 / fromIntegral (input + output))


-- Random normal distribution initializers


{- | Builds a matrix of the given size from samples of a normal distribution with the
given mean and standard deviation. The samples are produced from a stream of
uniform samples in the range @(0.0, 1.0)@ using the Box–Muller transform.

Arguments, in order:

* truncation bound: @'Just' eps@ keeps only samples with @abs (x - mean) < eps@.
  @eps@ must be positive; otherwise an error is raised (instead of looping forever,
  since no sample could ever pass the filter). 'Nothing' disables truncation.
* mean
* standard deviation

The generator is consumed but not returned, so split it before calling this
function if it is needed afterwards.
-}
randomNormal :: (UniformRange a, Floating a, Ord a, RandomGen g) => Maybe a -> a -> a -> g -> InitializerFn a
randomNormal truncated mean stdDev gen sizes@(input, output) =
    M.fromList sizes $ take (input * output) $ applyTruncation scaled
  where
    -- Pairs of uniform samples. The uniform range may include 0, and 'log 0' in
    -- Box–Muller would yield an infinite sample, so pairs whose first component is
    -- not strictly positive are skipped. When no zero is drawn, the stream is
    -- identical to the unfiltered one.
    uniformPairs = filter ((> 0) . fst) $ pairs $ uniformRs (0.0, 1.0) gen

    -- Standard normal samples, two per uniform pair.
    standard = concatMap ((\(n1, n2) -> [n1, n2]) . transformBoxMuller) uniformPairs

    -- Rescale to the requested mean and standard deviation.
    scaled = map ((+ mean) . (* stdDev)) standard

    applyTruncation = case truncated of
        Nothing -> id
        Just eps
            -- 'eps > 0' is False for NaN too, so NaN bounds are rejected as well.
            | eps > 0   -> filter (\x -> abs (x - mean) < eps)
            | otherwise -> error "randomNormal: truncation bound must be positive"

    -- Groups consecutive elements into pairs; a trailing element is paired with 1.
    pairs [] = []
    pairs [x] = [(x, 1)]
    pairs (a:b:xs) = (a, b) : pairs xs

    -- Box–Muller transform: maps two uniform samples (first one > 0) to two
    -- independent standard normal samples.
    transformBoxMuller (u1, u2) = let r = sqrt $ (-2.0) * log u1
                                      theta = 2.0 * pi * u2
                                  in (r * cos theta, r * sin theta)

{- | Builds a matrix of the given size from LeCun normal samples: mean @0@ and
standard deviation @sqrt (1 / input)@. Samples that are two or more standard
deviations away from the mean are discarded.

The generator is consumed but not returned, so split it before calling this
function if it is needed afterwards.
-}
lecunNormal :: (UniformRange a, Floating a, Ord a, RandomGen g) => g -> InitializerFn a
lecunNormal gen sizes@(input, _) = randomNormal (Just (2.0 * stdDev)) 0 stdDev gen sizes
  where
    stdDev = sqrt (1.0 / fromIntegral input)

{- | Builds a matrix of the given size from He normal samples: mean @0@ and
standard deviation @sqrt (2 / input)@. Samples that are two or more standard
deviations away from the mean are discarded.

The generator is consumed but not returned, so split it before calling this
function if it is needed afterwards.
-}
heNormal :: (UniformRange a, Floating a, Ord a, RandomGen g) => g -> InitializerFn a
heNormal gen sizes@(input, _) = randomNormal (Just (2.0 * stdDev)) 0 stdDev gen sizes
  where
    stdDev = sqrt (2.0 / fromIntegral input)

{- | Builds a matrix of the given size from Glorot normal samples: mean @0@ and
standard deviation @sqrt (2 / (input + output))@. Samples that are two or more
standard deviations away from the mean are discarded.

The generator is consumed but not returned, so split it before calling this
function if it is needed afterwards.
-}
glorotNormal :: (UniformRange a, Floating a, Ord a, RandomGen g) => g -> InitializerFn a
glorotNormal gen sizes@(input, output) = randomNormal (Just (2.0 * stdDev)) 0 stdDev gen sizes
  where
    stdDev = sqrt (2.0 / fromIntegral (input + output))


-- Matrix-like initializers


-- | Builds an identity matrix of size @input@ via 'M.identity'.
-- Calls 'error' when the two dimensions differ, since the matrix would not be square.
identity :: Num a => InitializerFn a
identity (input, output) =
    if input == output
        then M.identity input
        else error "Given dimensions do not represent square matrix"

{- | Builds an orthogonal matrix by applying 'M.orthogonalized' to a matrix of the
given size. That matrix is filled with samples from the standard normal
distribution (mean @0@, standard deviation @1@).

The normal samples are NOT truncated: 'randomNormal' is called with 'Nothing'.

The generator is consumed but not returned, so split it before calling this
function if it is needed afterwards.
-}
orthogonal :: (UniformRange a, Floating a, Ord a, RandomGen g) => g -> InitializerFn a
orthogonal gen sizes = M.orthogonalized $ randomNormal Nothing 0.0 1.0 gen sizes