{-# OPTIONS_GHC -Wno-missing-export-lists #-}
-- | Implementation of a fully connected neural network for classification
-- of MNIST digits with sized lists of rank 1 tensors (vectors)
-- as the trainable parameters. Sports 2 hidden layers. No mini-batches.
-- This is an exotic and fundamentally inefficient way of implementing
-- neural networks, but it's valuable for comparative benchmarking.
module MnistFcnnRanked1 where

import Prelude

import Data.Vector.Generic qualified as V
import GHC.TypeLits (KnownNat, Nat)

import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Ranked.Shape

import HordeAd
import HordeAd.Core.Ops (tfromListR)
import MnistData

-- | The differentiable type of all trainable parameters of this nn.
--
-- Three @(weights, bias)@ pairs, one per layer, in the order in which
-- 'afcnnMnist1' applies them. Each weight matrix is represented as a sized
-- list of row vectors; 'listMatmul1' produces one output element per row.
--
-- 1. First hidden layer: @widthHidden@ rows of length @SizeMnistGlyph@,
--    plus a bias vector of length @widthHidden@.
-- 2. Second hidden layer: @widthHidden2@ rows of length @widthHidden@,
--    plus a bias vector of length @widthHidden2@.
-- 3. Readout layer: @SizeMnistLabel@ rows of length @widthHidden2@,
--    plus a bias vector of length @SizeMnistLabel@.
--
-- Note: the second layer's weight rows always have scalar type 'Float',
-- whatever @r@ is. 'afcnnMnist1' converts with 'scast' on both sides of
-- that layer's product. This presumably exercises mixed-precision
-- differentiation (TODO confirm the intent).
type ADFcnnMnist1Parameters
       (target :: Target) (widthHidden :: Nat) (widthHidden2 :: Nat) r =
  ( ( ListR widthHidden (target (TKS '[SizeMnistGlyph] r))
    , target (TKS '[widthHidden] r) )
  , ( ListR widthHidden2 (target (TKS '[widthHidden] Float))
    , target (TKS '[widthHidden2] r) )
  , ( ListR SizeMnistLabel (target (TKS '[widthHidden2] r))
    , target (TKS '[SizeMnistLabel] r) )
  )

-- | An ad-hoc matrix multiplication analogue for matrices represented
-- as lists of vectors.
--
-- Each element of @weights@ is treated as one row of a @w2 x w1@ matrix.
-- The result is the length-@w2@ vector of dot products ('sdot0') of every
-- row with the input vector.
--
-- The input is first bound with 'tlet'. Presumably this keeps it shared
-- across all @w2@ dot products instead of duplicating it in a symbolic
-- term.
listMatmul1
  :: forall target r w1 w2.
     (ADReady target, GoodScalar r, KnownNat w1)
  => target (TKS '[w1] r) -> ListR w2 (target (TKS '[w1] r))
  -> target (TKS '[w2] r)
{-# INLINE listMatmul1 #-}  -- this doesn't want to specialize
listMatmul1 x0 weights = tlet x0 dotWithEveryRow
  where
    -- Dot every weight row with the bound input, then pack the resulting
    -- sized list of scalars into a single rank-1 tensor.
    dotWithEveryRow :: target (TKS '[w1] r) -> target (TKS '[w2] r)
    dotWithEveryRow x =
      tfromListR knownSTK (fmap (\row -> row `sdot0` x) weights)

-- | Fully connected neural network for the MNIST digit classification task.
-- There are two hidden layers and both use the same activation function.
-- The output layer uses a different activation function.
-- The widths of the two hidden layers are @widthHidden@ and @widthHidden2@,
-- respectively.
--
-- Each layer computes @activation (listMatmul1 input weights + bias)@.
-- The second hidden layer stores its weights as 'Float'. Its input is
-- therefore cast to 'Float' before the product, and the product is cast
-- back to @r@ before the bias is added.
--
-- The two 'SNat' arguments are matched with the 'SNat' pattern only to
-- bring 'KnownNat' evidence for the widths into scope. Their values are
-- otherwise unused. The shaped output is returned as a ranked tensor.
afcnnMnist1 :: forall target r widthHidden widthHidden2.
               (ADReady target, GoodScalar r, Differentiable r)
            => (forall n. KnownNat n
                => target (TKS '[n] r) -> target (TKS '[n] r))
            -> (target (TKS '[SizeMnistLabel] r)
                -> target (TKS '[SizeMnistLabel] r))
            -> SNat widthHidden -> SNat widthHidden2
            -> target (TKS '[SizeMnistGlyph] r)
            -> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
            -> target (TKR 1 r)
afcnnMnist1 factivationHidden factivationOutput SNat SNat
            datum ((hidden, bias), (hidden2, bias2), (readout, biasr)) =
  rfromS (factivationOutput logits)
  where
    -- First hidden layer, computed entirely in @r@.
    layer1 :: target (TKS '[widthHidden] r)
    layer1 = factivationHidden (listMatmul1 datum hidden + bias)
    -- Second hidden layer. The matrix product is done in 'Float'
    -- because that is the scalar type of its weights.
    layer2 :: target (TKS '[widthHidden2] r)
    layer2 =
      let productF :: target (TKS '[widthHidden2] Float)
          productF = listMatmul1 (scast layer1) hidden2
      in factivationHidden (scast productF + bias2)
    -- Readout layer, before the output activation is applied.
    logits :: target (TKS '[SizeMnistLabel] r)
    logits = listMatmul1 layer2 readout + biasr

-- | The neural network applied to concrete activation functions
-- and composed with the appropriate loss function.
--
-- The sample is a pair of ranked vectors @(glyph, expected)@:
--
-- * the glyph is converted to a shaped tensor with 'sfromR';
-- * it is run through 'afcnnMnist1', with 'logisticS' on both hidden
--   layers and 'softMax1S' on the output layer;
-- * the prediction is scored against @expected@ with 'lossCrossEntropyV'.
afcnnMnistLoss1
  :: (ADReady target, GoodScalar r, Differentiable r)
  => SNat widthHidden -> SNat widthHidden2
  -> (target (TKR 1 r), target (TKR 1 r))
  -> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
  -> target (TKScalar r)
afcnnMnistLoss1 widthHidden widthHidden2 (glyph, expected) adparams =
  lossCrossEntropyV expected prediction
  where
    prediction = afcnnMnist1 logisticS softMax1S widthHidden widthHidden2
                             (sfromR glyph) adparams
-- {-# SPECIALIZE afcnnMnistLoss1 :: (GoodScalar r, Differentiable r) => SNat widthHidden -> SNat widthHidden2 -> (AstTensor AstMethodLet FullSpan (TKR 1 r), AstTensor AstMethodLet FullSpan (TKR 1 r)) -> ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) widthHidden widthHidden2 r -> AstTensor AstMethodLet FullSpan (TKScalar r) #-}
{-# SPECIALIZE afcnnMnistLoss1 :: SNat widthHidden -> SNat widthHidden2 -> (AstTensor AstMethodLet FullSpan (TKR 1 Double), AstTensor AstMethodLet FullSpan (TKR 1 Double)) -> ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) widthHidden widthHidden2 Double -> AstTensor AstMethodLet FullSpan (TKScalar Double) #-}
{-# SPECIALIZE afcnnMnistLoss1 :: SNat widthHidden -> SNat widthHidden2 -> (AstTensor AstMethodLet FullSpan (TKR 1 Float), AstTensor AstMethodLet FullSpan (TKR 1 Float)) -> ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) widthHidden widthHidden2 Float -> AstTensor AstMethodLet FullSpan (TKScalar Float) #-}
{-# SPECIALIZE afcnnMnistLoss1 :: SNat widthHidden -> SNat widthHidden2 -> (ADVal Concrete (TKR 1 Double), ADVal Concrete (TKR 1 Double)) -> ADFcnnMnist1Parameters (ADVal Concrete) widthHidden widthHidden2 Double -> ADVal Concrete (TKScalar Double) #-}
{-# SPECIALIZE afcnnMnistLoss1 :: SNat widthHidden -> SNat widthHidden2 -> (ADVal Concrete (TKR 1 Float), ADVal Concrete (TKR 1 Float)) -> ADFcnnMnist1Parameters (ADVal Concrete) widthHidden widthHidden2 Float -> ADVal Concrete (TKScalar Float) #-}

-- | A function testing the neural network given testing set of inputs
-- and the trained parameters.
--
-- Returns the fraction of samples that are classified correctly. A sample
-- counts as correct when two indices agree:
--
-- * the index of the largest network output (computed with 'logisticS'
--   on the hidden layers and 'softMax1S' on the output layer);
-- * the index of the largest entry of the sample's label vector.
--
-- An empty test set yields @0@.
afcnnMnistTest1
  :: forall target widthHidden widthHidden2 r.
     (target ~ Concrete, GoodScalar r, Differentiable r)
  => SNat widthHidden -> SNat widthHidden2
  -> [MnistDataLinearR r]
  -> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
  -> r
afcnnMnistTest1 _ _ [] _ = 0
afcnnMnistTest1 widthHidden widthHidden2 dataList testParams =
  fromIntegral correct / fromIntegral total
  where
    total, correct :: Int
    total = length dataList
    correct = length [() | sample <- dataList, matchesLabels sample]
    -- Run the trained network on one (ranked) input vector.
    predict :: target (TKR 1 r) -> target (TKR 1 r)
    predict glyph1 =
      afcnnMnist1 logisticS softMax1S widthHidden widthHidden2
                  (sfromR glyph1) testParams
    -- Compare the arg-max of the prediction with the arg-max of the label.
    matchesLabels :: MnistDataLinearR r -> Bool
    matchesLabels (glyph, label) =
      V.maxIndex (rtoVector (predict (rconcrete glyph)))
      == V.maxIndex (Nested.rtoVector label)