{-# OPTIONS_GHC -Wno-missing-export-lists #-}
-- | Ranked tensor-based implementation of Recurrent Neural Network
-- for classification of MNIST digits. Sports 2 hidden layers.
module MnistRnnRanked2 where

import Prelude hiding (foldl')

import Data.Kind (Type)
import Data.List (foldl')
import Data.Vector.Generic qualified as V
import Data.Vector.Storable (Vector)
import GHC.TypeLits (KnownNat, Nat, type (+))

import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Ranked.Shape

import HordeAd
import MnistData

-- | The differentiable type of all trainable parameters of this nn.
-- Shaped version, statically checking all dimension widths.
type ADRnnMnistParametersShaped (target :: Target) width r =
  ( LayerWeigthsRNNShaped target SizeMnistHeight width r  -- first RNN layer
  , LayerWeigthsRNNShaped target width width r            -- second RNN layer
  , ( target (TKS '[SizeMnistLabel, width] r)   -- final dense layer weights
    , target (TKS '[SizeMnistLabel] r) ) )      -- final dense layer bias

-- | Trainable parameters of a single recurrent layer (shaped version),
-- for the given input and output widths.
type LayerWeigthsRNNShaped :: Target -> Nat -> Nat -> Type -> Type
type LayerWeigthsRNNShaped target in_width out_width r =
  ( target (TKS '[out_width, in_width] r)   -- input weight
  , target (TKS '[out_width, out_width] r)  -- state weight
  , target (TKS '[out_width] r) )           -- bias

-- | The differentiable type of all trainable parameters of this nn.
-- Ranked (dynamically-shaped) analogue of 'ADRnnMnistParametersShaped'.
type ADRnnMnistParameters target r =
  ( LayerWeigthsRNN target r  -- first RNN layer
  , LayerWeigthsRNN target r  -- second RNN layer
  , ( target (TKR 2 r)    -- final dense layer weights
    , target (TKR 1 r) ) )  -- final dense layer bias

-- | Trainable parameters of a single recurrent layer, ranked version;
-- components mirror 'LayerWeigthsRNNShaped'.
type LayerWeigthsRNN (target :: Target) r =
  ( target (TKR 2 r)   -- input weight
  , target (TKR 2 r)   -- state weight
  , target (TKR 1 r) ) -- bias

-- | Run a computation that expects an initial state tensor, feeding it
-- a zero tensor of the given shape as that state.
zeroStateR
  :: (BaseTensor target, GoodScalar r)
  => IShR n -> (target (TKR n r)  -- state
                    -> a)
  -> a
zeroStateR sh f = f (rrepl sh 0)

-- | Lift a step function over a single time slice to a function over
-- a whole sequence (the outermost tensor dimension), threading the state
-- through every step and returning the result of the last step.
unrollLastR :: forall target state c w r n.
               (BaseTensor target, GoodScalar r, KnownNat n)
            => (state -> target (TKR n r) -> w -> (c, state))
            -> (state -> target (TKR (1 + n) r) -> w -> (c, state))
unrollLastR f s0 xs w =
  let g :: (c, state) -> target (TKR n r) -> (c, state)
      -- Only the state is threaded; the result of all but the last step
      -- is discarded, so the initial @undefined@ result below is never
      -- forced (assuming the unravelled sequence is non-empty).
      g (_, !s) x = f s x w
  in foldl' g (undefined, s0) (runravelToList xs)

-- | A single recurrent layer with @tanh@ activation function.
rnnMnistLayerR
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 2 r)  -- ^ in state, @[out_width, batch_size]@
  -> target (TKR 2 r)  -- ^ input, @[in_width, batch_size]@
  -> LayerWeigthsRNN target r  -- ^ parameters
  -> target (TKR 2 r)  -- ^ output state, @[out_width, batch_size]@
rnnMnistLayerR s x (wX, wS, b) = case rshape s of
  _out_width :$: batch_size :$: ZSR ->
    -- The bias is replicated over the batch and transposed to line up
    -- with the @[out_width, batch_size]@ layout of the matmul results.
    let y = wX `rmatmul2` x + wS `rmatmul2` s
            + rtr (rreplicate batch_size b)
    in tanh y

-- TODO: represent state as a pair to avoid appending; tlet now supports this.
-- | Composition of two recurrent layers.
rnnMnistTwoR
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 2 r)  -- initial state, @[2 * out_width, batch_size]@
  -> PrimalOf target (TKR 2 r)  -- @[sizeMnistHeight, batch_size]@
  -> ( LayerWeigthsRNN target r  -- sizeMnistHeight out_width
     , LayerWeigthsRNN target r )  -- out_width out_width
  -> ( target (TKR 2 r)  -- @[out_width, batch_size]@
     , target (TKR 2 r) )  -- final state, @[2 * out_width, batch_size]@
rnnMnistTwoR s' x ((wX, wS, b), (wX2, wS2, b2)) = case rshape s' of
  out_width_x_2 :$: _batch_size :$: ZSR ->
    let out_width = out_width_x_2 `div` 2
        -- Bind the state with @tlet@ so that slicing it twice below
        -- doesn't duplicate the state term.
        s3 = tlet s' $ \s ->
          let s1 = rslice 0 out_width s          -- state of the first layer
              s2 = rslice out_width out_width s  -- state of the second layer
              vec1 = rnnMnistLayerR s1 (rfromPrimal x) (wX, wS, b)
              vec2 = rnnMnistLayerR s2 vec1 (wX2, wS2, b2)
          in rappend vec1 vec2  -- both new states, stacked back together
    in (rslice out_width out_width s3, s3)

-- | The two-layer recurrent nn with its state initialized to zero
-- and the result composed with a fully connected layer.
rnnMnistZeroR
  :: (ADReady target, GoodScalar r, Differentiable r)
  => Int  -- ^ batch_size
  -> PrimalOf target (TKR 3 r)
       -- ^ input data @[sizeMnistWidth, sizeMnistHeight, batch_size]@
  -> ADRnnMnistParameters target r  -- ^ parameters
  -> target (TKR 2 r)  -- ^ output classification @[SizeMnistLabel, batch_size]@
rnnMnistZeroR batch_size xs
              ((wX, wS, b), (wX2, wS2, b2), (w3, b3)) = case rshape b of
  out_width :$: ZSR ->
    -- The combined state of both layers is twice the layer width;
    -- the width is recovered from the first layer's bias shape.
    let sh = 2 * out_width :$: batch_size :$: ZSR
        (out, _s) = zeroStateR sh (unrollLastR rnnMnistTwoR) xs
                                  ((wX, wS, b), (wX2, wS2, b2))
    in w3 `rmatmul2` out + rtr (rreplicate batch_size b3)

-- | The neural network composed with the SoftMax-CrossEntropy loss function.
rnnMnistLossFusedR
  :: (ADReady target, ADReady (PrimalOf target), GoodScalar r, Differentiable r)
  => Int
  -> (PrimalOf target (TKR 3 r), PrimalOf target (TKR 2 r))  -- batch_size
  -> ADRnnMnistParameters target r  -- SizeMnistHeight out_width
  -> target (TKScalar r)
rnnMnistLossFusedR batch_size (glyphR, labelR) adparameters =
  let -- Move the batch dimension to the front, as the nn expects.
      xs = rtranspose [2, 1, 0] glyphR
      result = rnnMnistZeroR batch_size xs adparameters
      targets = rtr labelR
      loss = lossSoftMaxCrossEntropyR targets result
  in -- Average the summed loss over the batch.
     kfromPrimal (recip $ kconcrete $ fromIntegral batch_size)
     * loss

-- | A function testing the neural network given testing set of inputs
-- and the trained parameters. Returns the fraction of samples
-- whose maximal output component matches the maximal label component.
rnnMnistTestR
  :: forall target r.
     (target ~ Concrete, GoodScalar r, Differentiable r)
  => Int
  -> MnistDataBatchR r  -- batch_size
  -> ADRnnMnistParameters target r
  -> r
rnnMnistTestR 0 _ _ = 0  -- avoid division by zero on an empty batch
rnnMnistTestR batch_size (glyphR, labelR) testParams =
  let input :: target (TKR 3 r)
      input = rconcrete $ Nested.rtranspose [2, 1, 0] glyphR
      outputR :: Concrete (TKR 2 r)
      outputR =
        let nn :: ADRnnMnistParameters target r  -- SizeMnistHeight out_width
               -> target (TKR 2 r)  -- [SizeMnistLabel, batch_size]
            nn = rnnMnistZeroR batch_size input
        in nn testParams
      -- Transpose to @[batch_size, SizeMnistLabel]@ and unravel,
      -- yielding one output vector per sample.
      outputs = map rtoVector $ runravelToList
                $ rtranspose [1, 0] outputR
      labels = map rtoVector
               $ runravelToList @_ @(TKScalar r)
               $ rconcrete labelR
      matchesLabels :: Vector r -> Vector r -> Int
      matchesLabels output label
        | V.maxIndex output == V.maxIndex label = 1
        | otherwise = 0
  in fromIntegral (sum (zipWith matchesLabels outputs labels))
     / fromIntegral batch_size