{- | Provides learning rate functions - functions that return a coefficient which modulates how big the parameter updates are during training.
-}


module Synapse.NN.LearningRates
    ( -- * 'LearningRateFn' type alias and 'LearningRate' newtype

      LearningRateFn
    , LearningRate (LearningRate, unLearningRate)

      -- * Learning rate decay functions

    , exponentialDecay
    , inverseTimeDecay
    , polynomialDecay
    , cosineDecay
    , piecewiseConstantDecay
    ) where


-- | 'LearningRateFn' type alias represents functions that return a coefficient which modulates how big the parameter updates are during training.
type LearningRateFn a = Int -> a


-- | 'LearningRate' newtype wraps 'LearningRateFn's - functions that modulate how big the parameter updates are during training.
newtype LearningRate a = LearningRate
    { unLearningRate :: LearningRateFn a  -- ^ Unwraps 'LearningRate' newtype.
    }
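
-- A minimal usage sketch (illustrative only, not part of the module): any @Int -> a@ function is a valid
-- schedule, e.g. a constant one, and a training loop would typically unwrap the newtype with 'unLearningRate'
-- and apply it to the current step number.
--
-- > constantLR :: LearningRate Double
-- > constantLR = LearningRate (const 0.01)
-- >
-- > lrAtStep :: Int -> Double
-- > lrAtStep = unLearningRate constantLR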


-- Learning rate decay functions


-- | Takes initial learning rate, decay steps and decay rate and calculates exponential decay learning rate (@initial * decay_rate ^ (step / decay_steps)@).

exponentialDecay :: Num a => a -> Int -> a -> LearningRateFn a
exponentialDecay :: forall a. Num a => a -> Int -> a -> LearningRateFn a
exponentialDecay a
initial Int
steps a
rate Int
step = a
initial a -> a -> a
forall a. Num a => a -> a -> a
* a
rate a -> Int -> a
forall a b. (Num a, Integral b) => a -> b -> a
^ (Int
step Int -> Int -> Int
forall a. Integral a => a -> a -> a
`div` Int
steps)
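
-- Illustrative values (a sketch): with @exponentialDecay 0.1 1000 0.5@ the rate is halved once per full
-- 1000 steps, since the exponent is the integer quotient @step `div` steps@.
--
-- > lr :: LearningRateFn Double
-- > lr = exponentialDecay 0.1 1000 0.5
-- > -- lr 0    ~ 0.1     (0.5 ^ 0)
-- > -- lr 1500 ~ 0.05    (0.5 ^ 1)
-- > -- lr 2500 ~ 0.025   (0.5 ^ 2)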

-- | Takes initial learning rate, decay steps and decay rate and calculates inverse time decay learning rate (@initial / (1 + rate * step / steps)@).
inverseTimeDecay :: Fractional a => a -> Int -> a -> LearningRateFn a
inverseTimeDecay initial steps rate step = initial / (1.0 + rate * fromIntegral step / fromIntegral steps)
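
-- Illustrative values (a sketch): with @inverseTimeDecay 0.1 1000 2.0@ the rate follows
-- @0.1 / (1 + 2 * step / 1000)@.
--
-- > lr :: LearningRateFn Double
-- > lr = inverseTimeDecay 0.1 1000 2.0
-- > -- lr 0    ~ 0.1
-- > -- lr 500  ~ 0.05      (0.1 / 2)
-- > -- lr 1000 ~ 0.0333    (0.1 / 3)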

{- | Takes initial learning rate, decay steps, polynomial power and end learning rate and calculates polynomial decay learning rate
(@if step < steps then initial * (1 - step / steps) ** power else end@).
-}
polynomialDecay :: Floating a => a -> Int -> a -> a -> LearningRateFn a
polynomialDecay initial steps power end step
    | step < steps = initial * (1.0 - fromIntegral step / fromIntegral steps) ** power
    | otherwise    = end
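
-- Illustrative values (a sketch): with @polynomialDecay 0.1 1000 2.0 0.001@ the rate falls quadratically
-- towards zero over the first 1000 steps and the end value is used afterwards.
--
-- > lr :: LearningRateFn Double
-- > lr = polynomialDecay 0.1 1000 2.0 0.001
-- > -- lr 0    ~ 0.1
-- > -- lr 500  ~ 0.025   (0.1 * 0.5 ** 2)
-- > -- lr 1500 ~ 0.001   (past the decay window)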

{- | Takes initial learning rate, decay steps, alpha coefficient and optional warmup (warmup steps and warmup target learning rate)
and calculates cosine decay learning rate
(@((1 - alpha) * (0.5 * (1.0 + cos (pi * step / steps))) + alpha) *
(if warmup then (if step < warmupSteps then (warmupLR - initial) * step / warmupSteps else warmupLR) else initial)@).
-}
cosineDecay :: Floating a => a -> Int -> a -> Maybe (Int, a) -> LearningRateFn a
cosineDecay initial steps alpha warmup step =
    ((1.0 - alpha) * (0.5 * (1.0 + cos (pi * fromIntegral step / fromIntegral steps))) + alpha) *
    case warmup of
        Nothing                      -> initial
        Just (warmupSteps, warmupLR) -> if step < warmupSteps
                                            then (warmupLR - initial) * fromIntegral step / fromIntegral warmupSteps
                                            else warmupLR
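
-- Illustrative values (a sketch): without warmup, @cosineDecay 0.1 1000 0.0 Nothing@ follows half a cosine
-- wave from the full initial rate down to @alpha * initial@ over @steps@ steps.
--
-- > lr :: LearningRateFn Double
-- > lr = cosineDecay 0.1 1000 0.0 Nothing
-- > -- lr 0    ~ 0.1    (cos 0 = 1)
-- > -- lr 500  ~ 0.05   (cos (pi / 2) = 0)
-- > -- lr 1000 ~ 0.0    (cos pi = -1)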

{- | Takes a list of step boundaries with their learning rate values and a final rate value, and calculates piecewise constant decay learning rate
(@if step < bound1 then value1 else if step < bound2 then value2 else lastRate@ for @[(bound1, value1), (bound2, value2)]@).
-}
piecewiseConstantDecay :: [(Int, a)] -> a -> LearningRateFn a
piecewiseConstantDecay [] lastRate _ = lastRate
piecewiseConstantDecay ((stepBound, rateValue):xs) lastRate step
    | step < stepBound = rateValue
    | otherwise        = piecewiseConstantDecay xs lastRate step
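
-- Illustrative values (a sketch): boundaries are checked in order, and the final rate applies once the step
-- is past every boundary.
--
-- > lr :: LearningRateFn Double
-- > lr = piecewiseConstantDecay [(100, 0.1), (1000, 0.01)] 0.001
-- > -- lr 50   ~ 0.1     (step < 100)
-- > -- lr 500  ~ 0.01    (100 <= step < 1000)
-- > -- lr 5000 ~ 0.001   (past the last boundary)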