{-# LANGUAGE OverloadedStrings #-}
module Synapse.NN.Training
(
CallbackFnOnTrainBegin
, CallbackFnOnEpochBegin
, CallbackFnOnBatchBegin
, CallbackFnOnBatchEnd
, CallbackFnOnEpochEnd
, CallbackFnOnTrainEnd
, Callbacks
( Callbacks
, callbacksOnTrainBegin
, callbacksOnEpochBegin
, callbacksOnBatchBegin
, callbacksOnBatchEnd
, callbacksOnEpochEnd
, callbacksOnTrainEnd
)
, emptyCallbacks
, Hyperparameters
( Hyperparameters
, hyperparametersEpochs
, hyperparametersBatchSize
, hyperparametersDataset
, hyperparametersLearningRate
, hyperparametersLoss
, hyperparametersMetrics
)
, RecordedMetric (RecordedMetric, unRecordedMetric)
, train
) where
import Synapse.Tensors (SingletonOps(unSingleton))
import Synapse.Tensors.Vec (Vec(Vec), unVec)
import Synapse.Tensors.Mat (Mat)
import Synapse.Autograd (Symbolic, SymbolIdentifier(SymbolIdentifier), unSymbol, symbol, constSymbol, getGradientsOf)
import Synapse.NN.Layers.Layer (AbstractLayer(..))
import Synapse.NN.Optimizers (Optimizer(..), optimizerUpdateParameters)
import Synapse.NN.Batching (Sample(Sample), unDataset, shuffleDataset, VecDataset, batchVectors, BatchedDataset)
import Synapse.NN.LearningRates (LearningRate(LearningRate))
import Synapse.NN.Losses (Loss(Loss))
import Synapse.NN.Metrics (Metric(Metric))
import Data.Functor ((<&>))
import Control.Monad (forM_, when)
import Data.IORef (IORef, newIORef, writeIORef, readIORef, modifyIORef')
import System.Random (RandomGen)
import System.ProgressBar
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
-- | Callback that 'train' runs once, after the model and optimizer state have
-- been placed into references and before the first epoch starts.
--
-- Both arguments are mutable references, so a callback may replace their
-- contents.
type CallbackFnOnTrainBegin model optimizer a
    =  IORef (model a)                          -- reference holding the model
    -> IORef [OptimizerParameters optimizer a]  -- reference holding the optimizer state, one entry per model parameter
    -> IO ()
-- | Callback that 'train' runs at the start of every epoch, right after the
-- dataset has been shuffled and split into batches.
--
-- 'train' reads the epoch counter and the batched dataset back from their
-- references after these callbacks have run, so changes made here are used
-- for the epoch that is about to start.
type CallbackFnOnEpochBegin model optimizer a
    =  IORef Int                                -- reference holding the epoch counter (starts at 1)
    -> IORef (model a)                          -- reference holding the model
    -> IORef [OptimizerParameters optimizer a]  -- reference holding the optimizer state
    -> IORef (BatchedDataset a)                 -- reference holding this epoch's batches
    -> IO ()
-- | Callback that 'train' runs before every batch is processed.
--
-- The batch and the learning-rate value are read back from their references
-- after these callbacks have run, so a callback can substitute either of them
-- for the upcoming update step.
type CallbackFnOnBatchBegin model optimizer a
    =  IORef Int                                -- reference holding the epoch counter
    -> IORef Int                                -- reference holding the batch counter (starts at 1 in each epoch)
    -> IORef (model a)                          -- reference holding the model
    -> IORef [OptimizerParameters optimizer a]  -- reference holding the optimizer state
    -> IORef (Sample (Mat a))                   -- reference holding the batch about to be processed
    -> IORef a                                  -- reference holding the learning-rate value for this step
    -> IO ()
-- | Callback that 'train' runs after every batch, once the model and the
-- optimizer state have been updated.
--
-- The last argument holds the values computed for the batch: the loss value
-- first, followed by one value per metric in the order the metrics were
-- supplied. 'train' reads it back after the callbacks and records its content.
type CallbackFnOnBatchEnd model optimizer a
    =  IORef Int                                -- reference holding the epoch counter
    -> IORef Int                                -- reference holding the batch counter
    -> IORef (model a)                          -- reference holding the updated model
    -> IORef [OptimizerParameters optimizer a]  -- reference holding the updated optimizer state
    -> IORef (Vec a)                            -- reference holding the loss and metric values of this batch
    -> IO ()
-- | Callback that 'train' runs at the end of every epoch, after the last batch
-- of the epoch and before the epoch counter is incremented.
type CallbackFnOnEpochEnd model optimizer a
    =  IORef Int                                -- reference holding the epoch counter
    -> IORef (model a)                          -- reference holding the model
    -> IORef [OptimizerParameters optimizer a]  -- reference holding the optimizer state
    -> IO ()
-- | Callback that 'train' runs once, after the last epoch has finished.
--
-- 'train' reads all three references back after these callbacks and returns
-- their content, so a callback can post-process the final result.
type CallbackFnOnTrainEnd model optimizer a
    =  IORef (model a)                          -- reference holding the trained model
    -> IORef [OptimizerParameters optimizer a]  -- reference holding the final optimizer state
    -> IORef (Vec (RecordedMetric a))           -- reference holding the recorded loss and metric histories
    -> IO ()
-- | Collection of callbacks that 'train' invokes at fixed points of training.
--
-- Every field is a list; 'train' runs the functions of a list in order, each
-- with the same references.
data Callbacks model optimizer a = Callbacks
    { callbacksOnTrainBegin :: [CallbackFnOnTrainBegin model optimizer a]  -- ^ Run once before the first epoch.
    , callbacksOnEpochBegin :: [CallbackFnOnEpochBegin model optimizer a]  -- ^ Run at the start of every epoch.
    , callbacksOnBatchBegin :: [CallbackFnOnBatchBegin model optimizer a]  -- ^ Run before every batch.
    , callbacksOnBatchEnd   :: [CallbackFnOnBatchEnd model optimizer a]    -- ^ Run after every batch.
    , callbacksOnEpochEnd   :: [CallbackFnOnEpochEnd model optimizer a]    -- ^ Run at the end of every epoch.
    , callbacksOnTrainEnd   :: [CallbackFnOnTrainEnd model optimizer a]    -- ^ Run once after the last epoch.
    }
-- | 'Callbacks' value in which every hook list is empty, so 'train' invokes
-- no callback at all. Useful as a base for record updates that set only the
-- hooks of interest.
emptyCallbacks :: Callbacks model optimizer a
emptyCallbacks = Callbacks
    { callbacksOnTrainBegin = []
    , callbacksOnEpochBegin = []
    , callbacksOnBatchBegin = []
    , callbacksOnBatchEnd   = []
    , callbacksOnEpochEnd   = []
    , callbacksOnTrainEnd   = []
    }
-- | Settings that describe one training run performed by 'train'.
data Hyperparameters a = Hyperparameters
    { hyperparametersEpochs       :: Int             -- ^ Number of epochs; the epoch counter runs from 1 up to and including this value.
    , hyperparametersBatchSize    :: Int             -- ^ Batch size handed to 'batchVectors'; also the divisor used to size the metric buffers, so it must not be zero.
    , hyperparametersDataset      :: VecDataset a    -- ^ Dataset to train on; it is shuffled anew at the start of every epoch.
    , hyperparametersLearningRate :: LearningRate a  -- ^ Learning-rate function, called with the iteration number of every batch.
    , hyperparametersLoss         :: Loss a          -- ^ Loss function, applied to the batch's expected output and the model's prediction.
    , hyperparametersMetrics      :: Vec (Metric a)  -- ^ Metrics evaluated on every batch and recorded next to the loss.
    }
-- | History of a single quantity (the loss or one metric) collected by
-- 'train': one value per training iteration.
newtype RecordedMetric a = RecordedMetric
    { unRecordedMetric :: Vec a  -- ^ Recorded values, indexed by iteration.
    }
-- | Monadic @while@ loop.
--
-- Evaluates the condition; if it yields 'True', runs the body (discarding its
-- result) and starts over. Stops, returning @()@, the first time the condition
-- yields 'False'. The body is never run if the condition is 'False' initially.
whileM_ :: (Monad m) => m Bool -> m a -> m ()
whileM_ condition body = loop
  where
    loop = do
        continue <- condition
        when continue (body >> loop)
-- | Trains a model on the dataset given in the hyperparameters.
--
-- For every epoch the dataset is shuffled with the supplied generator and
-- split into batches. For every batch the model is run symbolically, the loss
-- (plus the regularizers' loss reported by the model) is differentiated, and
-- the optimizer produces new parameters and new optimizer state. A progress
-- bar is advanced by one step per batch.
--
-- The model, the optimizer state and several per-step values live in
-- references that are handed to the callbacks, so callbacks can observe and
-- replace them; see the callback type synonyms for which values are read back.
--
-- Arguments: the initial model, the optimizer, the hyperparameters, the
-- callbacks and the random generator used for shuffling.
--
-- Result: the trained model, the final optimizer state (one entry per model
-- parameter), the recorded histories (the loss first, then one entry per
-- metric, each with one value per iteration) and the advanced generator.
--
-- A batch size of zero makes the @div@ below raise an arithmetic exception.
train
    :: (Symbolic a, Floating a, Ord a, Show a, RandomGen g, AbstractLayer model, Optimizer optimizer)
    => model a
    -> optimizer a
    -> Hyperparameters a
    -> Callbacks model optimizer a
    -> g
    -> IO (model a, [OptimizerParameters optimizer a], Vec (RecordedMetric a), g)
train model optimizer (Hyperparameters epochs batchSize dataset (LearningRate lr) (Loss loss) (Vec metrics)) callbacks gen0 = do
    let modelIdentifier = SymbolIdentifier "m"
        inputIdentifier = SymbolIdentifier "input"

        -- Number of batches per epoch, rounded up so that a trailing partial
        -- batch is counted.
        -- NOTE(review): assumes 'batchVectors' yields exactly this many batches - confirm.
        batchesPerEpoch = (V.length (unVec (unDataset dataset)) + batchSize - 1) `div` batchSize
        totalIterations = epochs * batchesPerEpoch

    progressBar <- newProgressBar
        (defStyle { stylePrefix = exact <> msg " iterations"
                  , stylePostfix = msg "\nElapsed time: " <> elapsedTime renderDuration
                                <> msg "; Remaining time: " <> remainingTime renderDuration ""
                                <> msg "; Total time: " <> totalTime renderDuration ""
                  }
        ) 10 (Progress 0 totalIterations ())

    modelRef <- newIORef model
    optimizerParametersRef <- newIORef $
        map (optimizerInitialParameters optimizer . unSymbol) (getParameters modelIdentifier model)

    forM_ (callbacksOnTrainBegin callbacks) $ \callback ->
        callback modelRef optimizerParametersRef

    -- One mutable history per recorded quantity: slot 0 is the loss, the
    -- remaining slots follow the order of the metrics.
    allMetrics <- V.generateM (1 + V.length metrics) (const $ MV.new totalIterations)

    genRef <- newIORef gen0

    -- The counters live in references because callbacks receive them.
    epochRef <- newIORef 1
    whileM_ (readIORef epochRef <&> (<= epochs)) $ do
        currentGen <- readIORef genRef
        let (shuffledDataset, nextGen) = shuffleDataset dataset currentGen
        writeIORef genRef nextGen
        batchedDatasetRef <- newIORef $ batchVectors batchSize shuffledDataset

        forM_ (callbacksOnEpochBegin callbacks) $ \callback ->
            callback epochRef modelRef optimizerParametersRef batchedDatasetRef

        -- Read back after the callbacks so that their changes take effect.
        epoch <- readIORef epochRef
        batchedDataset <- readIORef batchedDatasetRef
        let batches = unVec (unDataset batchedDataset)

        batchIRef <- newIORef 1
        whileM_ (readIORef batchIRef <&> (<= V.length batches)) $ do
            incProgress progressBar 1

            batchI <- readIORef batchIRef
            -- 1-based index of this step across all epochs. Multiplying the
            -- epoch by the batch number would map different steps onto the
            -- same index and leave other history slots unwritten.
            let currentIteration = (epoch - 1) * batchesPerEpoch + batchI

            -- The loop condition above guarantees 1 <= batchI <= length batches.
            batchRef <- newIORef $ V.unsafeIndex batches (batchI - 1)
            lrValueRef <- newIORef $ lr currentIteration

            forM_ (callbacksOnBatchBegin callbacks) $ \callback ->
                callback epochRef batchIRef modelRef optimizerParametersRef batchRef lrValueRef

            Sample batchInput batchOutput <- readIORef batchRef
            lrValue <- readIORef lrValueRef
            currentModel <- readIORef modelRef

            let (prediction, regularizersLoss) =
                    symbolicForward modelIdentifier (symbol inputIdentifier batchInput) currentModel
                lossValue = loss (constSymbol batchOutput) prediction
                parameters = getParameters modelIdentifier currentModel

            optimizerParameters <- readIORef optimizerParametersRef
            let (parameters', optimizerParameters') =
                    unzip $ optimizerUpdateParameters
                        optimizer
                        (lrValue, getGradientsOf $ lossValue + regularizersLoss)
                        (zip parameters optimizerParameters)

            modifyIORef' modelRef (`updateParameters` parameters')
            writeIORef optimizerParametersRef optimizerParameters'

            -- The recorded loss excludes the regularizers' loss; metrics are
            -- evaluated on the prediction made before the parameter update.
            metricsValuesRef <- newIORef $ Vec $
                V.cons (unSingleton lossValue) $
                    V.map (\(Metric metric) -> unSingleton $ metric batchOutput (unSymbol prediction)) metrics

            forM_ (callbacksOnBatchEnd callbacks) $ \callback ->
                callback epochRef batchIRef modelRef optimizerParametersRef metricsValuesRef

            Vec metricsValues <- readIORef metricsValuesRef
            -- Pairing the histories with the values stops at the shorter of
            -- the two, so a callback that changed the number of values cannot
            -- cause an out-of-range history lookup.
            V.zipWithM_
                (\history value -> MV.write history (currentIteration - 1) value)
                allMetrics
                metricsValues

            modifyIORef' batchIRef (+ 1)

        forM_ (callbacksOnEpochEnd callbacks) $ \callback ->
            callback epochRef modelRef optimizerParametersRef

        modifyIORef' epochRef (+ 1)

    -- The mutable histories are not written to after this point, which is
    -- what makes freezing them without a copy acceptable.
    frozenMetrics <- V.mapM (fmap (RecordedMetric . Vec) . V.unsafeFreeze) allMetrics
    recordedMetricsRef <- newIORef (Vec frozenMetrics)

    forM_ (callbacksOnTrainEnd callbacks) $ \callback ->
        callback modelRef optimizerParametersRef recordedMetricsRef

    trainedModel <- readIORef modelRef
    trainedOptimizerParameters <- readIORef optimizerParametersRef
    recordedMetrics <- readIORef recordedMetricsRef
    finalGen <- readIORef genRef

    return (trainedModel, trainedOptimizerParameters, recordedMetrics, finalGen)