Safe Haskell | None |
---|---|
Language | GHC2024 |
MnistFcnnRanked2
Description
Ranked tensor-based implementation of a fully connected neural network for classification of MNIST digits. It has two hidden layers. There are no mini-batches, so the maximum rank of the tensors used is 2.
Synopsis
- type ADFcnnMnist2ParametersShaped (target :: Target) (widthHidden :: Nat) (widthHidden2 :: Nat) r q = ((target (TKS '[widthHidden, SizeMnistGlyph] r), target (TKS '[widthHidden] r)), (target (TKS '[widthHidden2, widthHidden] q), target (TKS '[widthHidden2] r)), (target (TKS '[SizeMnistLabel, widthHidden2] r), target (TKS '[SizeMnistLabel] r)))
- type ADFcnnMnist2Parameters (target :: Target) r q = ((target (TKR 2 r), target (TKR 1 r)), (target (TKR 2 q), target (TKR 1 r)), (target (TKR 2 r), target (TKR 1 r)))
- type XParams2 r q = X (ADFcnnMnist2Parameters Concrete r q)
- afcnnMnist2 :: (ADReady target, GoodScalar r, Differentiable r, GoodScalar q, Differentiable q) => (target (TKR 1 r) -> target (TKR 1 r)) -> (target (TKR 1 r) -> target (TKR 1 r)) -> target (TKR 1 r) -> ADFcnnMnist2Parameters target r q -> target (TKR 1 r)
- afcnnMnistLoss2 :: (ADReady target, GoodScalar r, Differentiable r, GoodScalar q, Differentiable q) => (target (TKR 1 r), target (TKR 1 r)) -> ADFcnnMnist2Parameters target r q -> target ('TKScalar r)
- afcnnMnistTest2 :: (target ~ Concrete, GoodScalar r, Differentiable r, GoodScalar q, Differentiable q) => [MnistDataLinearR r] -> ADFcnnMnist2Parameters target r q -> r
- mnistTrainBench2VTOGradient :: (GoodScalar r, Differentiable r, GoodScalar q, Differentiable q) => Proxy q -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int -> (Concrete (XParams2 r q), AstArtifactRev ('TKProduct (XParams2 r q) ('TKProduct ('TKR2 1 ('TKScalar r)) ('TKR2 1 ('TKScalar r)))) ('TKScalar r))
- mnistTrainBench2VTOGradientX :: (GoodScalar r, Differentiable r, GoodScalar q, Differentiable q) => Proxy q -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int -> (Concrete (XParams2 r q), AstArtifactRev ('TKProduct (XParams2 r q) ('TKProduct ('TKR2 1 ('TKScalar r)) ('TKR2 1 ('TKScalar r)))) ('TKScalar r))
Documentation
type ADFcnnMnist2ParametersShaped (target :: Target) (widthHidden :: Nat) (widthHidden2 :: Nat) r q = ((target (TKS '[widthHidden, SizeMnistGlyph] r), target (TKS '[widthHidden] r)), (target (TKS '[widthHidden2, widthHidden] q), target (TKS '[widthHidden2] r)), (target (TKS '[SizeMnistLabel, widthHidden2] r), target (TKS '[SizeMnistLabel] r)))
The differentiable type of all trainable parameters of this neural network. Shaped version: all dimension widths are checked statically.
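To illustrate what statically checked widths buy, here is a minimal sketch in plain Haskell, independent of this library. The hypothetical types Vec and Mat carry their widths as type-level naturals, much as the shaped TKS tensors carry theirs in their shape lists, so a dense layer whose matrix and vector widths disagree is rejected by the type checker. Vec, Mat, mkVec, mkMat, matVec and dense are illustrative names only, not part of this API.

```haskell
{-# LANGUAGE DataKinds, KindSignatures, ScopedTypeVariables #-}
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Hypothetical width-indexed vector and row-major matrix (not horde-ad's TKS).
newtype Vec (n :: Nat) = Vec [Double] deriving Show
newtype Mat (m :: Nat) (n :: Nat) = Mat [[Double]] deriving Show

-- Checks the list length against the type-level width once, at construction.
mkVec :: forall n. KnownNat n => [Double] -> Maybe (Vec n)
mkVec xs
  | length xs == fromIntegral (natVal (Proxy :: Proxy n)) = Just (Vec xs)
  | otherwise = Nothing

-- Checks the number of rows and the length of every row.
mkMat :: forall m n. (KnownNat m, KnownNat n) => [[Double]] -> Maybe (Mat m n)
mkMat rows
  | length rows == rowsWanted && all ((== colsWanted) . length) rows = Just (Mat rows)
  | otherwise = Nothing
  where
    rowsWanted = fromIntegral (natVal (Proxy :: Proxy m))
    colsWanted = fromIntegral (natVal (Proxy :: Proxy n))

-- The types force the inner dimensions to agree; no runtime check is needed.
matVec :: Mat m n -> Vec n -> Vec m
matVec (Mat rows) (Vec xs) = Vec [sum (zipWith (*) row xs) | row <- rows]

addVec :: Vec n -> Vec n -> Vec n
addVec (Vec xs) (Vec ys) = Vec (zipWith (+) xs ys)

-- One dense layer, W x + b, with all widths checked at compile time.
dense :: Mat m n -> Vec m -> Vec n -> Vec m
dense w b x = addVec (matVec w x) b
```

For example, passing a Mat 3 4 and a Vec 5 to dense is a compile-time type error, whereas with ranked parameters such a mismatch can only be detected at run time.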
type ADFcnnMnist2Parameters (target :: Target) r q = ((target (TKR 2 r), target (TKR 1 r)), (target (TKR 2 q), target (TKR 1 r)), (target (TKR 2 r), target (TKR 1 r)))
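The differentiable type of all trainable parameters of this neural network. Ranked version: only the rank of each tensor is part of its type, so dimension widths are not checked statically.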
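When only ranks are tracked, width consistency has to be checked dynamically. The sketch below, again in plain Haskell with hypothetical names (Matrix, Vector, Params2, checkParams2) and a single scalar type Double in place of r and q, models the ranked parameter triple with lists and checks that the widths of consecutive layers line up.

```haskell
-- Hypothetical ranked stand-ins: the rank is fixed by the Haskell type
-- (a matrix is [[Double]], a vector is [Double]), but widths are only
-- known at run time.
type Matrix = [[Double]]
type Vector = [Double]

-- Same nesting as ADFcnnMnist2Parameters: (weights, bias) for each layer.
type Params2 = ((Matrix, Vector), (Matrix, Vector), (Matrix, Vector))

-- Given the input width, checks every layer and returns the output width.
checkParams2 :: Int -> Params2 -> Either String Int
checkParams2 inputWidth ((w1, b1), (w2, b2), (w3, b3)) = do
  h1 <- layer "hidden layer 1" inputWidth w1 b1
  h2 <- layer "hidden layer 2" h1 w2 b2
  layer "output layer" h2 w3 b3
  where
    layer :: String -> Int -> Matrix -> Vector -> Either String Int
    layer name inWidth w b
      | any ((/= inWidth) . length) w =
          Left (name ++ ": every weight row must have length " ++ show inWidth)
      | length b /= length w =
          Left (name ++ ": bias length must equal the number of weight rows")
      | otherwise = Right (length w)
```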
afcnnMnist2 :: (ADReady target, GoodScalar r, Differentiable r, GoodScalar q, Differentiable q) => (target (TKR 1 r) -> target (TKR 1 r)) -> (target (TKR 1 r) -> target (TKR 1 r)) -> target (TKR 1 r) -> ADFcnnMnist2Parameters target r q -> target (TKR 1 r)
Fully connected neural network for the MNIST digit classification task.
There are two hidden layers and both use the same activation function.
The output layer uses a different activation function.
The widths of the two hidden layers are widthHidden and widthHidden2, respectively.
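The layer structure described above can be sketched with plain lists. This is an illustration under assumptions, not the library's code: Vector, Matrix, Params2, fcnn2 and logistic are hypothetical names, a single scalar type Double replaces r and q, and the first activation argument is assumed to be the hidden-layer one, following the order in which the description mentions them.

```haskell
type Matrix = [[Double]]
type Vector = [Double]

-- Hypothetical list-based stand-in for ADFcnnMnist2Parameters.
type Params2 = ((Matrix, Vector), (Matrix, Vector), (Matrix, Vector))

matVec :: Matrix -> Vector -> Vector
matVec m v = [sum (zipWith (*) row v) | row <- m]

addV :: Vector -> Vector -> Vector
addV = zipWith (+)

-- Two hidden layers sharing one activation, then an output layer with its own.
fcnn2 :: (Vector -> Vector)  -- hidden-layer activation (assumed first)
      -> (Vector -> Vector)  -- output-layer activation
      -> Vector              -- input glyph
      -> Params2
      -> Vector              -- one score per digit class
fcnn2 actHidden actOutput x ((w1, b1), (w2, b2), (w3, b3)) =
  let h1 = actHidden (matVec w1 x `addV` b1)   -- width widthHidden
      h2 = actHidden (matVec w2 h1 `addV` b2)  -- width widthHidden2
  in actOutput (matVec w3 h2 `addV` b3)

-- An example elementwise activation (the logistic sigmoid).
logistic :: Vector -> Vector
logistic = map (\z -> 1 / (1 + exp (negate z)))
```

Since each sample is a single vector, the largest rank involved is that of the weight matrices, which matches the rank-2 bound stated in the module description.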
afcnnMnistLoss2 :: (ADReady target, GoodScalar r, Differentiable r, GoodScalar q, Differentiable q) => (target (TKR 1 r), target (TKR 1 r)) -> ADFcnnMnist2Parameters target r q -> target ('TKScalar r)
The neural network applied to concrete activation functions and composed with the appropriate loss function.
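Which activation functions and loss are used is not spelled out here. A common choice for classification, assumed in the following sketch, is a softmax over the output scores followed by a cross-entropy loss against the one-hot label vector. softMax, crossEntropy and sampleLoss are hypothetical names, and the network is passed in as an arbitrary function from glyph to scores.

```haskell
type Vector = [Double]

-- Numerically stable softmax: subtract the maximum before exponentiating.
softMax :: Vector -> Vector
softMax zs = map (/ total) es
  where
    m = maximum zs
    es = map (\z -> exp (z - m)) zs
    total = sum es

-- Cross-entropy of predicted probabilities against a label vector:
-- minus the sum over i of label_i * log p_i.
crossEntropy :: Vector -> Vector -> Double
crossEntropy label probs = negate (sum (zipWith (\t p -> t * log p) label probs))

-- Loss for one (glyph, label) pair, given a network producing raw scores.
sampleLoss :: (Vector -> Vector) -> (Vector, Vector) -> Double
sampleLoss net (glyph, label) = crossEntropy label (softMax (net glyph))
```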
afcnnMnistTest2 :: (target ~ Concrete, GoodScalar r, Differentiable r, GoodScalar q, Differentiable q) => [MnistDataLinearR r] -> ADFcnnMnist2Parameters target r q -> r
A function that tests the neural network on a given test set of inputs, using the trained parameters.
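The result type r suggests a single summary score. Assuming it is the fraction of correctly classified test glyphs (an assumption, not confirmed here), the computation can be sketched as follows, taking the predicted digit to be the index of the largest output component. argMax and accuracy are hypothetical names.

```haskell
import Data.List (maximumBy)
import Data.Ord (comparing)

type Vector = [Double]

-- Index of the largest component, i.e. the predicted (or labelled) digit.
argMax :: Vector -> Int
argMax v = fst (maximumBy (comparing snd) (zip [0 ..] v))

-- Fraction of (glyph, label) pairs whose predicted class matches the label.
accuracy :: (Vector -> Vector) -> [(Vector, Vector)] -> Double
accuracy net samples = fromIntegral hits / fromIntegral (length samples)
  where
    hits = length [() | (glyph, label) <- samples, argMax (net glyph) == argMax label]
```

Note that an empty test set makes this sketch divide zero by zero and return NaN.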
mnistTrainBench2VTOGradient :: (GoodScalar r, Differentiable r, GoodScalar q, Differentiable q) => Proxy q -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int -> (Concrete (XParams2 r q), AstArtifactRev ('TKProduct (XParams2 r q) ('TKProduct ('TKR2 1 ('TKScalar r)) ('TKR2 1 ('TKScalar r)))) ('TKScalar r))
The loss function applied to randomly generated initial parameters and wrapped in the generation of a reverse-derivative artifact (AstArtifactRev). This helps share code between tests and benchmarks and separates compile time from run time when benchmarking (this function counts as compile time).
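The benefit of building a derivative artifact once and evaluating it many times can be shown on a toy scale. The sketch below uses symbolic differentiation of a one-variable expression tree, a deliberately simpler technique than the reverse-mode AST artifacts used here; Expr, deriv, eval, Artifact, mkArtifact and runArtifact are all hypothetical. mkArtifact plays the role of the compile-time step and runArtifact the role of the run-time step.

```haskell
-- Hypothetical miniature expression language over one variable X.
data Expr = X | Const Double | Add Expr Expr | Mul Expr Expr | Exp Expr
  deriving Show

-- Symbolic derivative with respect to X ("compile time").
deriv :: Expr -> Expr
deriv X = Const 1
deriv (Const _) = Const 0
deriv (Add a b) = Add (deriv a) (deriv b)
deriv (Mul a b) = Add (Mul (deriv a) b) (Mul a (deriv b))
deriv (Exp a) = Mul (Exp a) (deriv a)

-- Evaluation of an expression at a given point ("run time").
eval :: Double -> Expr -> Double
eval x X = x
eval _ (Const c) = c
eval x (Add a b) = eval x a + eval x b
eval x (Mul a b) = eval x a * eval x b
eval x (Exp a) = exp (eval x a)

-- The artifact pairs the original expression with its derivative expression.
data Artifact = Artifact { primal :: Expr, gradient :: Expr }

mkArtifact :: Expr -> Artifact
mkArtifact e = Artifact e (deriv e)

-- Evaluates value and derivative at many points; the derivative expression
-- is shared, so differentiation is not repeated per point.
runArtifact :: Artifact -> [Double] -> [(Double, Double)]
runArtifact art = map (\x -> (eval x (primal art), eval x (gradient art)))
```

Timing only runArtifact then measures evaluation alone, which mirrors the split between artifact generation and its use described above.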
mnistTrainBench2VTOGradientX :: (GoodScalar r, Differentiable r, GoodScalar q, Differentiable q) => Proxy q -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int -> (Concrete (XParams2 r q), AstArtifactRev ('TKProduct (XParams2 r q) ('TKProduct ('TKR2 1 ('TKScalar r)) ('TKR2 1 ('TKScalar r)))) ('TKScalar r))
A version of mnistTrainBench2VTOGradient that performs no simplification at all, not even the simplification done by the AST smart constructors. Intended for benchmarking.
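As a general illustration of what smart constructors do (a toy sketch with hypothetical names, not this library's AST), the following code builds the same derivative twice: once with the raw data constructors Add and Mul, and once with smart constructors add and mul that simplify while building, for example e * 1 = e, e + 0 = e and constant folding. Skipping this step, as the unsimplified variant above does for the real AST, produces larger terms in this toy example.

```haskell
data Expr = X | Const Double | Add Expr Expr | Mul Expr Expr
  deriving (Show, Eq)

-- Smart constructors: apply local algebraic simplifications on construction.
add :: Expr -> Expr -> Expr
add (Const 0) e = e
add e (Const 0) = e
add (Const a) (Const b) = Const (a + b)
add a b = Add a b

mul :: Expr -> Expr -> Expr
mul (Const 0) _ = Const 0
mul _ (Const 0) = Const 0
mul (Const 1) e = e
mul e (Const 1) = e
mul (Const a) (Const b) = Const (a * b)
mul a b = Mul a b

-- Derivative with respect to X, parameterised by the constructors it uses.
derivWith :: (Expr -> Expr -> Expr) -> (Expr -> Expr -> Expr) -> Expr -> Expr
derivWith _ _ X = Const 1
derivWith _ _ (Const _) = Const 0
derivWith a m (Add u v) = a (derivWith a m u) (derivWith a m v)
derivWith a m (Mul u v) = a (m (derivWith a m u) v) (m u (derivWith a m v))

-- Number of nodes in a term.
size :: Expr -> Int
size (Add a b) = 1 + size a + size b
size (Mul a b) = 1 + size a + size b
size _ = 1

derivRaw, derivSmart :: Expr -> Expr
derivRaw = derivWith Add Mul      -- no simplification
derivSmart = derivWith add mul    -- simplifying smart constructors
```

For Mul X X, derivRaw builds a term with 7 nodes, while derivSmart builds Add X X with 3.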