| Safe Haskell | None |
|---|---|
| Language | GHC2024 |
MnistRnnRanked2
Description
Ranked tensor-based implementation of a recurrent neural network for classification of MNIST digits. It sports two hidden layers.
Synopsis
- type ADRnnMnistParametersShaped (target :: Target) (width :: Nat) r = (LayerWeigthsRNNShaped target SizeMnistHeight width r, LayerWeigthsRNNShaped target width width r, (target (TKS '[SizeMnistLabel, width] r), target (TKS '[SizeMnistLabel] r)))
- type LayerWeigthsRNNShaped (target :: Target) (in_width :: Nat) (out_width :: Nat) r = (target (TKS '[out_width, in_width] r), target (TKS '[out_width, out_width] r), target (TKS '[out_width] r))
- type ADRnnMnistParameters (target :: Target) r = (LayerWeigthsRNN target r, LayerWeigthsRNN target r, (target (TKR 2 r), target (TKR 1 r)))
- type LayerWeigthsRNN (target :: Target) r = (target (TKR 2 r), target (TKR 2 r), target (TKR 1 r))
- zeroStateR :: forall target r (n :: Nat) a. (BaseTensor target, GoodScalar r) => IShR n -> (target (TKR n r) -> a) -> a
- unrollLastR :: forall target state c w r (n :: Nat). (BaseTensor target, GoodScalar r, KnownNat n) => (state -> target (TKR n r) -> w -> (c, state)) -> state -> target (TKR (1 + n) r) -> w -> (c, state)
- rnnMnistLayerR :: (ADReady target, GoodScalar r, Differentiable r) => target (TKR 2 r) -> target (TKR 2 r) -> LayerWeigthsRNN target r -> target (TKR 2 r)
- rnnMnistTwoR :: (ADReady target, GoodScalar r, Differentiable r) => target (TKR 2 r) -> PrimalOf target (TKR 2 r) -> (LayerWeigthsRNN target r, LayerWeigthsRNN target r) -> (target (TKR 2 r), target (TKR 2 r))
- rnnMnistZeroR :: (ADReady target, GoodScalar r, Differentiable r) => Int -> PrimalOf target (TKR 3 r) -> ADRnnMnistParameters target r -> target (TKR 2 r)
- rnnMnistLossFusedR :: (ADReady target, ADReady (PrimalOf target), GoodScalar r, Differentiable r) => Int -> (PrimalOf target (TKR 3 r), PrimalOf target (TKR 2 r)) -> ADRnnMnistParameters target r -> target ('TKScalar r)
- rnnMnistTestR :: (target ~ Concrete, GoodScalar r, Differentiable r) => Int -> MnistDataBatchR r -> ADRnnMnistParameters target r -> r
Documentation
type ADRnnMnistParametersShaped (target :: Target) (width :: Nat) r = (LayerWeigthsRNNShaped target SizeMnistHeight width r, LayerWeigthsRNNShaped target width width r, (target (TKS '[SizeMnistLabel, width] r), target (TKS '[SizeMnistLabel] r))) Source #
The differentiable type of all trainable parameters of this neural network. This is the shaped version, statically checking all dimension widths.
type LayerWeigthsRNNShaped (target :: Target) (in_width :: Nat) (out_width :: Nat) r = (target (TKS '[out_width, in_width] r), target (TKS '[out_width, out_width] r), target (TKS '[out_width] r)) Source #
type ADRnnMnistParameters (target :: Target) r = (LayerWeigthsRNN target r, LayerWeigthsRNN target r, (target (TKR 2 r), target (TKR 1 r))) Source #
The differentiable type of all trainable parameters of this neural network.
type LayerWeigthsRNN (target :: Target) r = (target (TKR 2 r), target (TKR 2 r), target (TKR 1 r)) Source #
zeroStateR :: forall target r (n :: Nat) a. (BaseTensor target, GoodScalar r) => IShR n -> (target (TKR n r) -> a) -> a Source #
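Judging by its name and signature, `zeroStateR` builds a zero tensor of the given shape and hands it to a continuation, so a computation can start from an all-zero state. A minimal list-based sketch of that pattern (the tensor is modelled as a plain list; `zeroStateList` is a hypothetical analogue, not the library API):

```haskell
-- Hedged sketch: pass a zero "state" of the given size to a
-- continuation, analogous to what zeroStateR appears to do for
-- ranked tensors of shape IShR n.
zeroStateList :: Int -> ([Double] -> a) -> a
zeroStateList n k = k (replicate n 0)
```

For example, `zeroStateList 3 sum` applies `sum` to `[0,0,0]`.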
unrollLastR :: forall target state c w r (n :: Nat). (BaseTensor target, GoodScalar r, KnownNat n) => (state -> target (TKR n r) -> w -> (c, state)) -> state -> target (TKR (1 + n) r) -> w -> (c, state) Source #
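From the signature, `unrollLastR` unrolls a single-step function over the outermost (time) dimension of a rank `1 + n` tensor, threading the state through and keeping only the result of the final step. A list-based sketch of that behaviour (the time dimension becomes a plain list; `unrollLastList` is an illustrative analogue assuming a non-empty input, not the library function):

```haskell
-- Hedged sketch of unrollLastR's recursion scheme: fold a step
-- function over the time steps, threading the state; the output c
-- of all but the last step is discarded.
unrollLastList :: (state -> x -> w -> (c, state))
               -> state -> [x] -> w -> (c, state)
unrollLastList f s0 xs w = foldl step (error "empty input", s0) xs
  where
    -- each step discards the previous c and advances the state
    step (_, s) x = f s x w
```

The `error` for the empty-input case is never forced as long as at least one time step is present, mirroring the assumption that the unrolled dimension is non-empty.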
rnnMnistLayerR Source #

Arguments
:: (ADReady target, GoodScalar r, Differentiable r) | |
=> target (TKR 2 r) | in state |
-> target (TKR 2 r) | input |
-> LayerWeigthsRNN target r | parameters |
-> target (TKR 2 r) | output state |
A single recurrent layer with the tanh activation function.
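A standard recurrent layer of this kind computes `h' = tanh (wIn · x + wState · h + b)`; assuming that is the form used here, a single-sample sketch with plain lists (the helpers `matVec` and `addV` are illustrative, not part of the library):

```haskell
type Vec = [Double]
type Mat = [[Double]]  -- row-major matrix

-- matrix-vector product
matVec :: Mat -> Vec -> Vec
matVec m v = [sum (zipWith (*) row v) | row <- m]

-- elementwise vector addition
addV :: Vec -> Vec -> Vec
addV = zipWith (+)

-- Hedged sketch of one recurrent step on a single sample:
-- new state = tanh (wIn x + wState s + b)  (assumed form).
rnnLayer :: Vec -> Vec -> (Mat, Mat, Vec) -> Vec
rnnLayer s x (wIn, wState, b) =
  map tanh (matVec wIn x `addV` matVec wState s `addV` b)
```

The library version operates on rank-2 tensors, i.e. on a whole batch of such samples at once.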
rnnMnistTwoR :: (ADReady target, GoodScalar r, Differentiable r) => target (TKR 2 r) -> PrimalOf target (TKR 2 r) -> (LayerWeigthsRNN target r, LayerWeigthsRNN target r) -> (target (TKR 2 r), target (TKR 2 r)) Source #
Composition of two recurrent layers.
rnnMnistZeroR Source #

Arguments
:: (ADReady target, GoodScalar r, Differentiable r) | |
=> Int | batch_size |
-> PrimalOf target (TKR 3 r) | input data |
-> ADRnnMnistParameters target r | parameters |
-> target (TKR 2 r) | output classification |
The two-layer recurrent neural network with its state initialized to zero and the result composed with a fully connected layer.
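Putting the pieces together, the forward pass presumably looks like: start both layers from a zero state, unroll the stacked layers over the time steps of the input, then apply a dense layer to the final hidden state. A self-contained list-based sketch under those assumptions (all names and the exact layer form `tanh (wIn x + wState h + b)` are illustrative, not the library API):

```haskell
type Vec = [Double]
type Mat = [[Double]]  -- row-major matrix

matVec :: Mat -> Vec -> Vec
matVec m v = [sum (zipWith (*) row v) | row <- m]

addV :: Vec -> Vec -> Vec
addV = zipWith (+)

-- one recurrent step: h' = tanh (wIn x + wState h + b) (assumed form)
step :: (Mat, Mat, Vec) -> Vec -> Vec -> Vec
step (wIn, wState, b) h x =
  map tanh (matVec wIn x `addV` matVec wState h `addV` b)

-- Hedged sketch of the full forward pass: zero initial states,
-- two stacked recurrent layers unrolled over time, then a dense
-- (fully connected) layer on the final hidden state of layer two.
rnnForward :: ((Mat, Mat, Vec), (Mat, Mat, Vec), (Mat, Vec))
           -> [Vec] -> Vec
rnnForward (l1, l2, (wOut, bOut)) xs =
  let width = let (_, _, b1) = l1 in length b1  -- hidden width
      go (h1, h2) x = let h1' = step l1 h1 x
                          h2' = step l2 h2 h1'
                      in (h1', h2')
      (_, hFinal) = foldl go (replicate width 0, replicate width 0) xs
  in matVec wOut hFinal `addV` bOut
```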
rnnMnistLossFusedR :: (ADReady target, ADReady (PrimalOf target), GoodScalar r, Differentiable r) => Int -> (PrimalOf target (TKR 3 r), PrimalOf target (TKR 2 r)) -> ADRnnMnistParameters target r -> target ('TKScalar r) Source #
The neural network composed with the SoftMax-CrossEntropy loss function.
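For reference, the standard softmax cross-entropy on a single sample: normalize the logits with softmax (shifting by the maximum for numerical stability) and take the negative log-likelihood of the one-hot target. A sketch assuming that standard definition (the fused library version differs in how it composes with the network, not in the mathematics):

```haskell
-- Hedged sketch of the standard softmax cross-entropy loss for a
-- single sample: target is a one-hot (or probability) vector,
-- logits are the raw network outputs.
softmaxCrossEntropy :: [Double] -> [Double] -> Double
softmaxCrossEntropy target logits =
  let m     = maximum logits             -- shift for stability
      es    = map (\z -> exp (z - m)) logits
      total = sum es
      probs = map (/ total) es           -- softmax
  in negate (sum (zipWith (\t p -> t * log p) target probs))
```

With all logits equal, softmax is uniform, so the loss against any one-hot target over 10 classes is `log 10`.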
rnnMnistTestR :: (target ~ Concrete, GoodScalar r, Differentiable r) => Int -> MnistDataBatchR r -> ADRnnMnistParameters target r -> r Source #
A function testing the neural network, given a testing set of inputs and the trained parameters.
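Test functions of this shape typically report classification accuracy: the fraction of samples whose argmax prediction matches the label. A sketch of that metric, assuming `rnnMnistTestR` measures accuracy (the helpers here are illustrative, not the library API):

```haskell
-- index of the largest element; ties resolve to the later index
argmax :: [Double] -> Int
argmax xs = snd (maximum (zip xs [0 ..]))

-- Hedged sketch: fraction of (prediction scores, true label)
-- pairs where the argmax of the scores equals the label.
accuracy :: [([Double], Int)] -> Double
accuracy pairs =
  let correct =
        length [() | (scores, label) <- pairs, argmax scores == label]
  in fromIntegral correct / fromIntegral (length pairs)
```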