{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -Wno-missing-export-lists #-}
-- | Shaped tensor-based implementation of a Convolutional Neural Network
-- for classification of MNIST digits. It has 2 hidden layers.
--
-- With the current CPU backend it is slow enough that it's hard to see
-- whether it trains.
module MnistCnnShaped2 where

import Prelude

import Data.Vector.Generic qualified as V
import Data.Vector.Storable (Vector)
import GHC.TypeLits (fromSNat, type (*), type (+), type (<=), type Div)

import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Shaped.Shape

import HordeAd
import MnistData

-- | The differentiable type of all trainable parameters of this nn,
-- in the shaped representation, so that all dimension widths are
-- checked statically.
--
-- The four pairs are, in order: the kernel and bias of the first
-- convolutional layer, the kernel and bias of the second convolutional
-- layer, the weights and biases of the dense hidden layer, and the
-- weights and biases of the readout layer.
--
-- Because subtraction complicates type inference for positive numbers,
-- @kh@ denotes the kernel height minus one and, analogously, @kw@ is the
-- kernel width minus one.
type ADCnnMnistParametersShaped
       (target :: Target) h w kh kw c_out n_hidden r =
  ( -- first convolution: 1 input channel to c_out channels
    ( target (TKS '[c_out, 1, kh + 1, kw + 1] r)
    , target (TKS '[c_out] r) )
    -- second convolution: c_out channels to c_out channels
  , ( target (TKS '[c_out, c_out, kh + 1, kw + 1] r)
    , target (TKS '[c_out] r) )
    -- dense hidden layer over the flattened, twice-pooled features
  , ( target (TKS '[n_hidden, c_out * Div h 4 * Div w 4] r)
    , target (TKS '[n_hidden] r) )
    -- readout layer producing one value per label
  , ( target (TKS '[SizeMnistLabel, n_hidden] r)
    , target (TKS '[SizeMnistLabel] r) )
  )

-- | A single convolutional layer with @relu@ and @maxPool@.
--
-- The kernel has shape @[c_out, c_in, kh + 1, kw + 1]@ and the input
-- images have shape @[batch_size, c_in, h, w]@. The convolution keeps the
-- spatial size @h x w@. The per-channel @bias@ is then broadcast over the
-- batch and both spatial dimensions and added, @relu@ is applied, and
-- finally a max pooling with kernel size 2 and stride 2 halves both
-- spatial dimensions.
--
-- In 'convMnistTwoS' the @c_in@ type parameter is 1 (grayscale) for the
-- first layer and @c_out@ for the second, but this function works for
-- any @c_in@.
--
-- The 'SNat' arguments are matched only to bring the corresponding
-- 'KnownNat' instances into scope; their values are not otherwise used.
convMnistLayerS
  :: forall kh kw h w c_in c_out batch_size target r.
     ( 1 <= kh
     , 1 <= kw  -- wrongly reported as redundant due to plugins
     , ADReady target, GoodScalar r, Differentiable r )
  => SNat kh -> SNat kw -> SNat h -> SNat w
  -> SNat c_in -> SNat c_out -> SNat batch_size
  -> target (TKS '[c_out, c_in, kh + 1, kw + 1] r)
  -> target (TKS '[batch_size, c_in, h, w] r)
  -> target (TKS '[c_out] r)
  -> target (TKS '[batch_size, c_out, h `Div` 2, w `Div` 2] r)
convMnistLayerS SNat SNat SNat SNat SNat SNat SNat
                ker input bias =
  let yConv = conv2dUnpaddedS ker input  -- [batch_size, c_out, h, w]
      -- The replications build [batch_size, h, w, c_out]; the transpose
      -- then moves the channel axis to position 1 so the shape matches
      -- yConv.
      biasStretched = stranspose @'[0, 3, 1, 2]
                      $ sreplicate {-@batch_size-}
                      $ sreplicate {-@h-}
                      $ sreplicate {-@w-} bias
      yRelu = reluS $ yConv + biasStretched
  in maxPool2dUnpaddedS @2 @2 yRelu

-- | Composition of two convolutional layers, followed by a dense hidden
-- layer with @relu@ and a dense readout layer.
--
-- Each convolutional layer (see 'convMnistLayerS') halves both spatial
-- dimensions, so the flattened features fed to the dense hidden layer
-- have width @c_out * (h `Div` 4) * (w `Div` 4)@. The result is the raw
-- output of the readout layer (no softmax is applied here), with one
-- column per image.
convMnistTwoS
  :: forall kh kw h w c_out n_hidden batch_size target r.
       -- @h@ and @w@ are fixed with MNIST data, but not with test data
     ( 1 <= kh  -- kernel height is large enough
     , 1 <= kw  -- kernel width is large enough
     , ADReady target, GoodScalar r, Differentiable r )
  => SNat kh -> SNat kw -> SNat h -> SNat w
  -> SNat c_out -> SNat n_hidden -> SNat batch_size
       -- ^ these boilerplate lines tie type parameters to the corresponding
       -- SNat value parameters denoting basic dimensions
  -> PrimalOf target (TKS '[batch_size, 1, h, w] r)  -- ^ input images
  -> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
       -- ^ parameters
  -> target (TKS '[SizeMnistLabel, batch_size] r)  -- ^ output classification
convMnistTwoS kh@SNat kw@SNat h@SNat w@SNat
              c_out@SNat _n_hidden@SNat batch_size@SNat
              input
              ( (ker1, bias1), (ker2, bias2)
              , (weightsDense, biasesDense), (weightsReadout, biasesReadout) ) =
  -- Halving twice (with floor division) equals dividing by 4 for every
  -- natural number, but the type checker does not derive this on its own.
  -- Asserting it lets the twice-pooled output be reshaped to the width
  -- that the dense layer's weights expect.
  assumeEquality @(Div (Div w 2) 2) @(Div w 4) $
  assumeEquality @(Div (Div h 2) 2) @(Div h 4) $
  let -- First layer, 1 input channel (grayscale) to c_out channels;
      -- shape [batch_size, c_out, h `Div` 2, w `Div` 2].
      t1 = convMnistLayerS kh kw h w
                           (SNat @1) c_out batch_size
                           ker1 (sfromPrimal input) bias1
      -- Second layer, c_out to c_out channels; shape
      -- [batch_size, c_out, (h `Div` 2) `Div` 2, (w `Div` 2) `Div` 2],
      -- i.e. [batch_size, c_out, h `Div` 4, w `Div` 4] by the assumptions.
      t2 = convMnistLayerS kh kw (SNat @(h `Div` 2)) (SNat @(w `Div` 2))
                           c_out c_out batch_size
                           ker2 t1 bias2
      -- Flatten each image's features;
      -- shape [batch_size, c_out * (h `Div` 4) * (w `Div` 4)].
      m1 = sreshape t2
      -- Dense hidden layer with one column per image;
      -- shape [n_hidden, batch_size].
      denseLayer = weightsDense `smatmul2` str m1
                   + str (sreplicate biasesDense)
  in weightsReadout `smatmul2` reluS denseLayer
     + str (sreplicate biasesReadout)

-- | The neural network composed with the SoftMax-CrossEntropy loss function.
--
-- The glyphs are reshaped to gain a single (grayscale) channel axis, run
-- through 'convMnistTwoS', and the resulting class scores (shape
-- @[SizeMnistLabel, batch_size]@) are compared with the transposed labels
-- by 'lossSoftMaxCrossEntropyS'. The loss is then scaled by
-- @1 / batch_size@ (presumably turning a per-batch sum into a per-sample
-- mean — TODO confirm against 'lossSoftMaxCrossEntropyS').
--
-- An empty batch yields 0, mirroring 'convMnistTestS', instead of scaling
-- the loss by @recip 0@, which would produce a non-finite result.
convMnistLossFusedS
  :: forall kh kw h w c_out n_hidden batch_size target r.
     ( h ~ SizeMnistHeight, w ~ SizeMnistWidth
     , 1 <= kh
     , 1 <= kw
     , ADReady target, ADReady (PrimalOf target)
     , GoodScalar r, Differentiable r )
  => SNat kh -> SNat kw
  -> SNat c_out
  -> SNat n_hidden -> SNat batch_size
  -> ( PrimalOf target (TKS '[batch_size, h, w] r)
     , PrimalOf target (TKS '[batch_size, SizeMnistLabel] r) )
  -> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
  -> target (TKScalar r)
convMnistLossFusedS _ _ _ _ batch_size@SNat _ _
  | sNatValue batch_size == 0 = 0  -- avoid multiplying by recip 0
convMnistLossFusedS kh@SNat kw@SNat
                    c_out@SNat n_hidden@SNat batch_size@SNat
                    (glyphS, labelS) adparameters =
  let -- Insert the singleton input-channel axis expected by the conv layers.
      input :: PrimalOf target (TKS '[batch_size, 1, h, w] r)
      input = sreshape glyphS
      result = convMnistTwoS kh kw (SNat @h) (SNat @w)
                             c_out n_hidden batch_size
                             input adparameters
      -- The network emits one column per sample, so labels are transposed
      -- to @[SizeMnistLabel, batch_size]@ to match.
      targets = str labelS
      loss = lossSoftMaxCrossEntropyS targets result
  in kfromPrimal (recip $ kconcrete $ fromInteger $ fromSNat batch_size) * loss

-- | A function testing the neural network given testing set of inputs
-- and the trained parameters.
--
-- Returns the fraction of samples in the batch for which the index of the
-- highest network score equals the index of the highest entry of the label
-- vector. An empty batch yields 0 rather than dividing by zero.
convMnistTestS
  :: forall kh kw h w c_out n_hidden batch_size target r.
     ( h ~ SizeMnistHeight, w ~ SizeMnistWidth
     , 1 <= kh
     , 1 <= kw
     , target ~ Concrete
     , GoodScalar r, Differentiable r )
  => SNat kh -> SNat kw
  -> SNat c_out
  -> SNat n_hidden -> SNat batch_size
  -> MnistDataBatchS batch_size r
  -> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
  -> r
convMnistTestS _ _ _ _ batch_size@SNat _ _
  | sNatValue batch_size == 0 = 0
convMnistTestS kh@SNat kw@SNat
               c_out@SNat n_hidden@SNat batch_size@SNat
               (glyphS, labelS) testParams =
  fromIntegral hits / fromInteger (fromSNat batch_size)
  where
    -- The glyphs gain a singleton channel axis before entering the network.
    batchInput :: target (TKS '[batch_size, 1, h, w] r)
    batchInput = sconcrete (Nested.sreshape knownShS glyphS)

    -- The network with everything but its trainable parameters fixed.
    predict :: ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
            -> target (TKS '[SizeMnistLabel, batch_size] r)
    predict = convMnistTwoS kh kw (SNat @h) (SNat @w)
                            c_out n_hidden batch_size
                            batchInput

    -- Class scores, one column per sample.
    scores :: Concrete (TKS '[SizeMnistLabel, batch_size] r)
    scores = predict testParams

    -- Transpose so that unravelling yields one score vector per sample.
    perSample :: [Vector r]
    perSample = map stoVector (sunravelToList (stranspose @'[1, 0] scores))

    expected :: [Vector r]
    expected =
      map stoVector (sunravelToList @_ @_ @(TKScalar r) (sconcrete labelS))

    sameArgmax :: Vector r -> Vector r -> Bool
    sameArgmax out lab = V.maxIndex out == V.maxIndex lab

    hits :: Int
    hits = length (filter id (zipWith sameArgmax perSample expected))