{- | This module provides datatypes and functions that implement neural network training.
-}


-- 'OverloadedStrings' is needed to write the progress bar's text as string literals.

{-# LANGUAGE OverloadedStrings #-}


module Synapse.NN.Training
    ( -- * 'Callbacks' datatype and associated type aliases


      CallbackFnOnTrainBegin
    , CallbackFnOnEpochBegin
    , CallbackFnOnBatchBegin
    , CallbackFnOnBatchEnd
    , CallbackFnOnEpochEnd
    , CallbackFnOnTrainEnd

    , Callbacks
          ( Callbacks
          , callbacksOnTrainBegin
          , callbacksOnEpochBegin
          , callbacksOnBatchBegin
          , callbacksOnBatchEnd
          , callbacksOnEpochEnd
          , callbacksOnTrainEnd
          )
    , emptyCallbacks

    -- * 'Hyperparameters' datatype


    , Hyperparameters
          ( Hyperparameters
          , hyperparametersEpochs
          , hyperparametersBatchSize
          , hyperparametersDataset
          , hyperparametersLearningRate
          , hyperparametersLoss
          , hyperparametersMetrics
          )

    , RecordedMetric (RecordedMetric, unRecordedMetric)

      -- * Training


    , train
    ) where


import Synapse.Tensors (SingletonOps(unSingleton))
import Synapse.Tensors.Vec (Vec(Vec), unVec)
import Synapse.Tensors.Mat (Mat)

import Synapse.Autograd (Symbolic, SymbolIdentifier(SymbolIdentifier), unSymbol, symbol, constSymbol, getGradientsOf)

import Synapse.NN.Layers.Layer (AbstractLayer(..))
import Synapse.NN.Optimizers (Optimizer(..), optimizerUpdateParameters)
import Synapse.NN.Batching (Sample(Sample), unDataset, shuffleDataset, VecDataset, batchVectors, BatchedDataset)
import Synapse.NN.LearningRates (LearningRate(LearningRate))
import Synapse.NN.Losses (Loss(Loss))
import Synapse.NN.Metrics (Metric(Metric))

import Data.Functor ((<&>))
import Control.Monad (forM_, when)
import Data.IORef (IORef, newIORef, writeIORef, readIORef, modifyIORef')

import System.Random (RandomGen)

import System.ProgressBar

import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV


-- | Signature of a hook that 'train' runs once, before the first epoch starts.
type CallbackFnOnTrainBegin model optimizer a
    =  IORef (model a)                          -- ^ Model as it is before any training step.
    -> IORef [OptimizerParameters optimizer a]  -- ^ Optimizer parameters, freshly initialised for every model parameter.
    -> IO ()

-- | Signature of a hook that 'train' runs at the start of every epoch,
-- after the dataset has been shuffled and split into batches.
type CallbackFnOnEpochBegin model optimizer a
    =  IORef Int                                -- ^ Number of the current epoch (the first epoch is 1).
    -> IORef (model a)                          -- ^ Model as it is when the epoch starts.
    -> IORef [OptimizerParameters optimizer a]  -- ^ Optimizer parameters as they are when the epoch starts.
    -> IORef (BatchedDataset a)                 -- ^ Shuffled and batched dataset that this epoch will iterate over.
    -> IO ()

-- | Signature of a hook that 'train' runs right before a batch is processed.
type CallbackFnOnBatchBegin model optimizer a
    =  IORef Int                                -- ^ Number of the current epoch.
    -> IORef Int                                -- ^ Number of the current batch within the epoch (the first batch is 1).
    -> IORef (model a)                          -- ^ Model as it is before this batch updates it.
    -> IORef [OptimizerParameters optimizer a]  -- ^ Optimizer parameters as they are before this batch updates them.
    -> IORef (Sample (Mat a))                   -- ^ Batch about to be processed; it is read back after the hooks run.
    -> IORef a                                  -- ^ Learning rate for this batch; it is read back after the hooks run.
    -> IO ()

-- | Signature of a hook that 'train' runs right after a batch has been processed.
type CallbackFnOnBatchEnd model optimizer a
    =  IORef Int                                -- ^ Number of the current epoch.
    -> IORef Int                                -- ^ Number of the current batch within the epoch.
    -> IORef (model a)                          -- ^ Model after this batch's parameter update.
    -> IORef [OptimizerParameters optimizer a]  -- ^ Optimizer parameters after this batch's update.
    -> IORef (Vec a)                            -- ^ Values measured on this batch: the loss first, then every metric.
    -> IO ()

-- | Signature of a hook that 'train' runs once every batch of an epoch has been processed.
type CallbackFnOnEpochEnd model optimizer a
    =  IORef Int                                -- ^ Number of the epoch that just finished.
    -> IORef (model a)                          -- ^ Model as it is at the end of the epoch.
    -> IORef [OptimizerParameters optimizer a]  -- ^ Optimizer parameters as they are at the end of the epoch.
    -> IO ()

-- | Signature of a hook that 'train' runs once, after the last epoch has finished.
type CallbackFnOnTrainEnd model optimizer a
    =  IORef (model a)                            -- ^ Model as it is after training.
    -> IORef [OptimizerParameters optimizer a]    -- ^ Optimizer parameters as they are after training.
    -> IORef (Vec (RecordedMetric a))             -- ^ Recorded history: the loss first, then every metric.
    -> IO ()

{- | 'Callbacks' bundles every user hook that 'train' invokes.

Each hook receives 'IORef's that point at the live training state, so a hook can both
inspect that state and modify it, which lets it influence training in arbitrary ways.

Use this interface with care: inconsistent modifications can break training entirely.
-}
data Callbacks model optimizer a = Callbacks
    { -- | Hooks run once, before the first epoch.
      callbacksOnTrainBegin :: [CallbackFnOnTrainBegin model optimizer a]
      -- | Hooks run at the start of every epoch.
    , callbacksOnEpochBegin :: [CallbackFnOnEpochBegin model optimizer a]
      -- | Hooks run before every batch is processed.
    , callbacksOnBatchBegin :: [CallbackFnOnBatchBegin model optimizer a]
      -- | Hooks run after every batch is processed.
    , callbacksOnBatchEnd   :: [CallbackFnOnBatchEnd   model optimizer a]
      -- | Hooks run at the end of every epoch.
    , callbacksOnEpochEnd   :: [CallbackFnOnEpochEnd   model optimizer a]
      -- | Hooks run once, after the last epoch.
    , callbacksOnTrainEnd   :: [CallbackFnOnTrainEnd   model optimizer a]
    }

-- | A 'Callbacks' value with no hooks registered at all.
-- It is a convenient starting point for building your own callbacks with record update syntax.
emptyCallbacks :: Callbacks model optimizer a
emptyCallbacks = Callbacks
    { callbacksOnTrainBegin = []
    , callbacksOnEpochBegin = []
    , callbacksOnBatchBegin = []
    , callbacksOnBatchEnd   = []
    , callbacksOnEpochEnd   = []
    , callbacksOnTrainEnd   = []
    }


-- | 'Hyperparameters' gathers the whole configuration of a single 'train' run.
data Hyperparameters a = Hyperparameters
    { -- | How many passes over the dataset the training makes.
      hyperparametersEpochs       :: Int
      -- | Number of samples grouped into each training batch.
    , hyperparametersBatchSize    :: Int
      -- | Dataset of vector-function samples; 'train' reshuffles it every epoch.
    , hyperparametersDataset      :: VecDataset a
      -- | 'LearningRate' schedule that provides the step size for every batch.
    , hyperparametersLearningRate :: LearningRate a
      -- | 'Loss' that is minimised during training.
    , hyperparametersLoss         :: Loss a
      -- | 'Metric's that are recorded on every batch, in addition to the loss.
    , hyperparametersMetrics      :: Vec (Metric a)
    }

-- | 'RecordedMetric' holds the history of one recorded quantity (the loss or a single metric).
newtype RecordedMetric a = RecordedMetric
    { -- | Recorded values, one per training iteration, in chronological order.
      unRecordedMetric :: Vec a
    }


-- | 'whileM_' runs a monadic @while@ loop.
--
-- The condition is evaluated first. While it yields 'True', the body runs and the
-- loop repeats; the first 'False' stops it. The results of the body are discarded.
whileM_ :: (Monad m) => m Bool -> m a -> m ()
whileM_ p f = loop
  where
    loop = do
        keepGoing <- p
        when keepGoing $ do
            _ <- f
            loop

-- | 'train' function allows training neural networks on datasets with specified parameters.
--
-- Training lasts 'hyperparametersEpochs' epochs. At the start of every epoch the dataset is
-- shuffled with the provided generator and split into batches of 'hyperparametersBatchSize'
-- samples. Processing one batch is one training iteration. Iterations are counted globally
-- over the whole training, starting from 1, and this counter is the value passed to the
-- learning rate function.
--
-- On every iteration the loss (without regularizers) and all metrics are recorded. The
-- returned 'RecordedMetric's contain one value per completed iteration, so they are
-- simply shorter if a callback ends training early.
--
-- NOTE: storage for recorded values is preallocated for
-- @epochs * ceil (datasetSize / batchSize)@ iterations. If callbacks make training run
-- more iterations than that, recording fails with a bounds-check error from 'MV.write'.
train
    :: (Symbolic a, Floating a, Ord a, Show a, RandomGen g, AbstractLayer model, Optimizer optimizer)
    => model a                                                                     -- ^ Trained model.
    -> optimizer a                                                                 -- ^ Optimizer that will be used during training.
    -> Hyperparameters a                                                           -- ^ Hyperparameters of training.
    -> Callbacks model optimizer a                                                 -- ^ Callbacks that will be used during training.
    -> g                                                                           -- ^ Generator of random values that will be used to shuffle dataset.
    -> IO (model a, [OptimizerParameters optimizer a], Vec (RecordedMetric a), g)  -- ^ Updated model, optimizer parameters at the end of training, vector of recorded metrics (loss is also recorded and is the first in vector), updated generator of random values.
train model optimizer (Hyperparameters epochs batchSize dataset (LearningRate lr) (Loss loss) (Vec metrics)) callbacks gen0 = do
    let samplesCount    = V.length $ unVec $ unDataset dataset
        -- Ceiling division: a trailing partial batch still counts as one iteration.
        batchesPerEpoch = (samplesCount + batchSize - 1) `div` batchSize
        totalIterations = epochs * batchesPerEpoch

    progressBar <- newProgressBar
        (defStyle { stylePrefix = exact <> msg " iterations"
                  , stylePostfix = msg "\nElapsed time: " <> elapsedTime renderDuration
                                <> msg "; Remaining time: " <> remainingTime renderDuration ""
                                <> msg "; Total time: " <> totalTime renderDuration ""
                  }
        ) 10 (Progress 0 totalIterations ())

    modelRef <- newIORef model
    optimizerParametersRef <- newIORef $ optimizerInitialParameters optimizer . unSymbol <$> getParameters modelIdentifier model

    forM_ (callbacksOnTrainBegin callbacks) $ \fn -> fn modelRef optimizerParametersRef

    -- One mutable column per recorded quantity: the loss first, then every metric in order.
    -- Slots start out uninitialised (bottom), so only written slots may ever be returned.
    allMetrics <- V.generateM (1 + V.length metrics) (const $ MV.new totalIterations)

    genRef <- newIORef gen0
    -- Global, 1-based iteration counter shared by all epochs. It drives both the learning
    -- rate schedule and the slot in which each iteration's metrics are stored.
    iterationRef <- newIORef (0 :: Int)

    epochRef <- newIORef (1 :: Int)
    whileM_ (readIORef epochRef <&> (<= epochs)) $ do
        currentGen <- readIORef genRef
        let (shuffledDataset, gen') = shuffleDataset dataset currentGen
        writeIORef genRef gen'

        batchedDatasetRef <- newIORef $ batchVectors batchSize shuffledDataset
        forM_ (callbacksOnEpochBegin callbacks) $ \fn -> fn epochRef modelRef optimizerParametersRef batchedDatasetRef

        -- Read after the epoch-begin hooks so that hooks can replace the batched dataset.
        batches <- readIORef batchedDatasetRef <&> unVec . unDataset
        batchIRef <- newIORef (1 :: Int)
        whileM_ (readIORef batchIRef <&> (<= V.length batches)) $ do
            incProgress progressBar 1

            modifyIORef' iterationRef (+ 1)
            currentIteration <- readIORef iterationRef
            batchI <- readIORef batchIRef

            -- Bounds-checked: hooks are allowed to rewrite the batch counter.
            batchRef <- newIORef $ batches V.! (batchI - 1)
            lrValueRef <- newIORef $ lr currentIteration
            forM_ (callbacksOnBatchBegin callbacks) $ \fn -> fn epochRef batchIRef modelRef optimizerParametersRef batchRef lrValueRef

            -- Batch and learning rate are read back, so batch-begin hooks may override them.
            (Sample batchInput batchOutput) <- readIORef batchRef
            lrValue <- readIORef lrValueRef
            currentModel <- readIORef modelRef

            let (prediction, regularizersLoss) = symbolicForward modelIdentifier (symbol inputIdentifier batchInput) currentModel
                lossValue  = loss (constSymbol batchOutput) prediction
                parameters = getParameters modelIdentifier currentModel
                gradients  = getGradientsOf $ lossValue + regularizersLoss

            optimizerParameters <- readIORef optimizerParametersRef
            let (parameters', optimizerParameters') =
                    unzip $ optimizerUpdateParameters optimizer (lrValue, gradients) (zip parameters optimizerParameters)

            modifyIORef' modelRef (`updateParameters` parameters')
            writeIORef optimizerParametersRef optimizerParameters'

            metricsValuesRef <- newIORef $ Vec $ V.cons (unSingleton lossValue) $
                V.map (\(Metric metric) -> unSingleton $ metric batchOutput (unSymbol prediction)) metrics
            forM_ (callbacksOnBatchEnd callbacks) $ \fn -> fn epochRef batchIRef modelRef optimizerParametersRef metricsValuesRef

            -- zipWithM_ stops at the shorter vector, so a hook that changes the number of
            -- values cannot index past the allocated columns.
            (Vec metricsValues) <- readIORef metricsValuesRef
            V.zipWithM_ (\column value -> MV.write column (currentIteration - 1) value) allMetrics metricsValues

            modifyIORef' batchIRef (+ 1)

        forM_ (callbacksOnEpochEnd callbacks) $ \fn -> fn epochRef modelRef optimizerParametersRef
        modifyIORef' epochRef (+ 1)

    -- Return only the slots that were actually written. This matters when callbacks
    -- stop training early, because the remaining preallocated slots are still bottom.
    completedIterations <- readIORef iterationRef
    recordedMetricsRef <- V.mapM (fmap (RecordedMetric . Vec . V.take completedIterations) . V.unsafeFreeze) allMetrics
                          >>= newIORef . Vec

    forM_ (callbacksOnTrainEnd callbacks) $ \fn -> fn modelRef optimizerParametersRef recordedMetricsRef

    trainedModel <- readIORef modelRef
    trainedOptimizerParameters <- readIORef optimizerParametersRef
    recordedMetrics <- readIORef recordedMetricsRef
    finalGen <- readIORef genRef

    return (trainedModel, trainedOptimizerParameters, recordedMetrics, finalGen)
  where
    modelIdentifier = SymbolIdentifier "m"
    inputIdentifier = SymbolIdentifier "input"