module Synapse.NN.Layers.Initializers
(
InitializerFn
, Initializer (Initializer, unInitializer)
, constants
, zeroes
, ones
, randomUniform
, lecunUniform
, heUniform
, glorotUniform
, randomNormal
, lecunNormal
, heNormal
, glorotNormal
, identity
, orthogonal
) where
import Synapse.Tensors.Mat (Mat)
import qualified Synapse.Tensors.Mat as M
import System.Random (uniformListR, uniformRs, UniformRange, RandomGen)
-- | Function that builds a matrix for a requested @(input, output)@ size.
type InitializerFn a = (Int, Int) -> Mat a

-- | An 'InitializerFn' wrapped in a newtype.
newtype Initializer a = Initializer
  { unInitializer :: InitializerFn a -- ^ the wrapped initializer function
  }
-- | Initializer that fills a matrix of the requested size with one value.
constants :: Num a => a -> InitializerFn a
constants value = \(rows, cols) -> M.replicate (rows, cols) value
-- | Initializer that fills the matrix with @0@ (see 'constants').
zeroes :: Num a => InitializerFn a
zeroes dims = constants 0 dims
-- | Initializer that fills the matrix with @1@ (see 'constants').
ones :: Num a => InitializerFn a
ones dims = constants 1 dims
-- | Initializer drawing @input * output@ values uniformly from @bounds@.
--
-- The values come from 'uniformListR' seeded with @gen@. They are handed to
-- 'M.fromList' together with the requested size. The updated generator
-- returned by 'uniformListR' is discarded.
randomUniform :: (UniformRange a, RandomGen g) => (a, a) -> g -> InitializerFn a
randomUniform bounds gen dims@(rows, cols) = M.fromList dims values
  where
    (values, _) = uniformListR (rows * cols) bounds gen
-- | LeCun uniform initializer.
--
-- Samples uniformly from @[-limit, limit]@ with @limit = sqrt (3 / input)@,
-- where @input@ is the first component of the requested size.
lecunUniform :: (UniformRange a, Floating a, RandomGen g) => g -> InitializerFn a
lecunUniform gen sizes@(input, _) = randomUniform bounds gen sizes
  where
    bounds = (negate limit, limit)
    limit = sqrt (3.0 / fromIntegral input)
-- | He uniform initializer.
--
-- Samples uniformly from @[-limit, limit]@ with @limit = sqrt (6 / input)@,
-- where @input@ is the first component of the requested size.
heUniform :: (UniformRange a, Floating a, RandomGen g) => g -> InitializerFn a
heUniform gen sizes@(input, _) = randomUniform bounds gen sizes
  where
    bounds = (negate limit, limit)
    limit = sqrt (6.0 / fromIntegral input)
-- | Glorot (Xavier) uniform initializer.
--
-- Samples uniformly from @[-limit, limit]@ with
-- @limit = sqrt (6 / (input + output))@.
glorotUniform :: (UniformRange a, Floating a, RandomGen g) => g -> InitializerFn a
glorotUniform gen sizes@(nIn, nOut) = randomUniform (negate limit, limit) gen sizes
  where
    limit = sqrt (6.0 / fromIntegral (nIn + nOut))
-- | Initializer producing normally distributed values.
--
-- Uniform samples in @[0, 1]@ are drawn from @gen@ with 'uniformRs' and
-- grouped into pairs. Each pair is turned into two standard-normal values by
-- the Box–Muller transform, and those values are scaled to the requested
-- @mean@ and @stdDev@. The first @input * output@ resulting values are passed
-- to 'M.fromList' with the requested size.
--
-- When @truncated@ is @'Just' eps@, only values with @abs (x - mean) < eps@
-- are kept (rejection sampling). 'Nothing' keeps every value. A non-positive
-- (or NaN) @eps@ cannot accept any value. For a non-empty matrix this raises
-- an error instead of looping forever.
randomNormal
  :: (UniformRange a, Floating a, Ord a, RandomGen g)
  => Maybe a -- ^ optional truncation bound around the mean
  -> a       -- ^ mean
  -> a       -- ^ standard deviation
  -> g       -- ^ random generator
  -> InitializerFn a
randomNormal truncated mean stdDev gen sizes@(input, output) =
  M.fromList sizes $ take (input * output) samples
  where
    -- Box–Muller evaluates @log u1@. The uniform range is closed, so
    -- u1 == 0 can occur. Drop those pairs rather than emitting
    -- Infinity/NaN weights.
    uniformPairs = filter ((> 0) . fst) $ pairs $ uniformRs (0.0, 1.0) gen

    standard = concatMap (pairToList . transformBoxMuller) uniformPairs

    scaled = map ((+ mean) . (* stdDev)) standard

    samples = case truncated of
      Nothing -> scaled
      Just eps
        -- No value can satisfy @abs (x - mean) < eps@, so the filter would
        -- never yield anything. Written as @not (eps > 0)@ so NaN is caught.
        | not (eps > 0) -> error "randomNormal: truncation bound must be positive"
        | otherwise -> filter (\x -> abs (x - mean) < eps) scaled

    -- Group a stream into consecutive pairs, ignoring a trailing element.
    pairs (a : b : rest) = (a, b) : pairs rest
    pairs _ = []

    pairToList (x, y) = [x, y]

    -- Box–Muller: two uniforms (u1 in (0, 1]) to two standard normals.
    transformBoxMuller (u1, u2) = (r * cos theta, r * sin theta)
      where
        r = sqrt $ (-2.0) * log u1
        theta = 2.0 * pi * u2
-- | LeCun normal initializer.
--
-- Samples from a normal distribution centred on 0 and truncated to values
-- strictly within two (pre-truncation) standard deviations of the mean.
-- Truncating at ±2σ shrinks the standard deviation by a factor of about
-- 0.8796. The pre-truncation standard deviation is therefore scaled up by
-- that factor, so the samples end up with standard deviation
-- @sqrt (1 / input)@. This is the correction Keras' @VarianceScaling@ applies.
lecunNormal :: (UniformRange a, Floating a, Ord a, RandomGen g) => g -> InitializerFn a
lecunNormal gen sizes@(input, _) =
  randomNormal (Just $ 2.0 * stdDev) 0 stdDev gen sizes
  where
    -- Standard deviation the samples should have after truncation.
    targetStdDev = sqrt $ 1.0 / fromIntegral input
    -- Std of a unit normal truncated to [-2, 2]
    -- (scipy.stats.truncnorm.std(a=-2, b=2)).
    truncationFactor = 0.87962566103423978
    stdDev = targetStdDev / truncationFactor
-- | He normal initializer.
--
-- Samples from a normal distribution centred on 0 and truncated to values
-- strictly within two (pre-truncation) standard deviations of the mean.
-- Truncating at ±2σ shrinks the standard deviation by a factor of about
-- 0.8796. The pre-truncation standard deviation is therefore scaled up by
-- that factor, so the samples end up with standard deviation
-- @sqrt (2 / input)@. This is the correction Keras' @VarianceScaling@ applies.
heNormal :: (UniformRange a, Floating a, Ord a, RandomGen g) => g -> InitializerFn a
heNormal gen sizes@(input, _) =
  randomNormal (Just $ 2.0 * stdDev) 0 stdDev gen sizes
  where
    -- Standard deviation the samples should have after truncation.
    targetStdDev = sqrt $ 2.0 / fromIntegral input
    -- Std of a unit normal truncated to [-2, 2]
    -- (scipy.stats.truncnorm.std(a=-2, b=2)).
    truncationFactor = 0.87962566103423978
    stdDev = targetStdDev / truncationFactor
-- | Glorot (Xavier) normal initializer.
--
-- Samples from a normal distribution centred on 0 and truncated to values
-- strictly within two (pre-truncation) standard deviations of the mean.
-- Truncating at ±2σ shrinks the standard deviation by a factor of about
-- 0.8796. The pre-truncation standard deviation is therefore scaled up by
-- that factor, so the samples end up with standard deviation
-- @sqrt (2 / (input + output))@. This is the correction Keras'
-- @VarianceScaling@ applies.
glorotNormal :: (UniformRange a, Floating a, Ord a, RandomGen g) => g -> InitializerFn a
glorotNormal gen sizes@(input, output) =
  randomNormal (Just $ 2.0 * stdDev) 0 stdDev gen sizes
  where
    -- Standard deviation the samples should have after truncation.
    targetStdDev = sqrt $ 2.0 / fromIntegral (input + output)
    -- Std of a unit normal truncated to [-2, 2]
    -- (scipy.stats.truncnorm.std(a=-2, b=2)).
    truncationFactor = 0.87962566103423978
    stdDev = targetStdDev / truncationFactor
-- | Initializer returning @M.identity input@ for a square size.
--
-- Calls 'error' when @input /= output@.
identity :: Num a => InitializerFn a
identity (input, output)
  | input == output = M.identity input
  | otherwise = error "Given dimensions do not represent square matrix"
-- | Orthogonal initializer.
--
-- Draws an untruncated standard-normal matrix (mean 0, standard deviation 1)
-- with 'randomNormal' and passes it through 'M.orthogonalized'.
orthogonal :: (UniformRange a, Floating a, Ord a, RandomGen g) => g -> InitializerFn a
orthogonal gen sizes = M.orthogonalized (randomNormal Nothing 0.0 1.0 gen sizes)