{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-- | Tests of the "MnistCnnShaped2" convolutional neural network
-- using a few different optimization pipelines.
--
-- With the current CPU backend it's slow enough that it's hard to see
-- if it trains.
module TestMnistCNNS
  ( testTrees
  ) where

import Prelude

import Control.Monad (foldM, unless)
import GHC.TypeLits (KnownNat, type (<=))
import System.IO (hPutStrLn, stderr)
import System.Random
import Test.Tasty
import Test.Tasty.HUnit hiding (assert)
import Text.Printf

import Data.Array.Nested.Shaped.Shape

import HordeAd
import HordeAd.Core.Adaptor
import HordeAd.Core.AstEnv
import HordeAd.Core.AstFreshId
import HordeAd.Core.AstInterpret

import EqEpsilon

import MnistCnnShaped2 qualified
import MnistData

-- TODO: optimize enough that it can run for one full epoch in reasonable time
-- and then verify it trains down to ~20% validation error in a short enough
-- time to include such a training run in tests.

-- | All tests exported by this module: the three CNN test groups
-- defined below (variants A, I and O).
--
-- Each list element is itself a 'TestTree', so no extra wrapping is needed.
testTrees :: [TestTree]
testTrees = [ tensorADValMnistTestsCNNSA
            , tensorADValMnistTestsCNNSI
            , tensorADValMnistTestsCNNSO
            ]

-- | Tensor kind ('X') of the complete "MnistCnnShaped2" parameter set,
-- instantiated at the 'Concrete' target and MNIST image dimensions.
-- The remaining indices are kernel height/width, output channels,
-- hidden-layer width and the scalar type.
type XParams kh kw c_out n_hidden r =
  X ( MnistCnnShaped2.ADCnnMnistParametersShaped
        Concrete
        SizeMnistHeight SizeMnistWidth
        kh kw
        c_out n_hidden
        r )

-- POPL differentiation, straight via the ADVal instance of RankedTensor,
-- which side-steps vectorization.
mnistTestCaseCNNSA
  :: forall kh kw r.
     ( 1 <= kh, 1 <= kw
     , Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r )
  => String
  -> Int -> Int -> SNat kh -> SNat kw -> Int -> Int -> Int -> Int -> r
  -> TestTree
mnistTestCaseCNNSA :: forall (kh :: Natural) (kw :: Natural) r.
((<=) @Natural 1 kh, (<=) @Natural 1 kw, Differentiable r,
 GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> SNat kh
-> SNat kw
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNSA String
prefix Int
epochs Int
maxBatches kh :: SNat kh
kh@SNat kh
SNat kw :: SNat kw
kw@SNat kw
SNat Int
c_outInt Int
n_hiddenInt
                   Int
miniBatchSizeInt Int
totalBatchSize r
expected =
  Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
c_outInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
 -> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_c_outSNat :: SNat c_out) ->
  Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
n_hiddenInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
 -> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_n_hiddenSNat :: SNat n_hidden) ->
  Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
miniBatchSizeInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
 -> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
miniBatchSize :: SNat miniBatchSize) ->
  let targetInit :: Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit =
        (Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
 StdGen)
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
forall a b. (a, b) -> a
fst ((Concrete
    (X (ADCnnMnistParametersShaped
          Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
  StdGen)
 -> Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> (Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
    StdGen)
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue
                @(Concrete (X (MnistCnnShaped2.ADCnnMnistParametersShaped
                                 Concrete SizeMnistHeight SizeMnistWidth
                                 kh kw c_out n_hidden r)))
                Double
0.4 (Int -> StdGen
mkStdGen Int
44)
      name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
             String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
                        , Int -> String
forall a. Show a => a -> String
show (SNat kh -> Int
forall (n :: Natural). SNat n -> Int
sNatValue SNat kh
kh), Int -> String
forall a. Show a => a -> String
show (SNat kw -> Int
forall (n :: Natural). SNat n -> Int
sNatValue SNat kw
kw)
                        , Int -> String
forall a. Show a => a -> String
show Int
c_outInt, Int -> String
forall a. Show a => a -> String
show Int
n_hiddenInt
                        , Int -> String
forall a. Show a => a -> String
show Int
miniBatchSizeInt
                        , Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK (SingletonTK
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
 -> Int)
-> SingletonTK
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams kh kw c_out n_hidden r)
                        , Int -> String
forall a. Show a => a -> String
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit) ]
      ftest :: KnownNat batch_size
            => MnistDataBatchS batch_size r
            -> Concrete (XParams kh kw c_out n_hidden r) -> r
      ftest :: forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest @batch_size MnistDataBatchS batch_size r
mnistData Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
pars =
        SNat kh
-> SNat kw
-> SNat n
-> SNat n
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADCnnMnistParametersShaped
     Concrete SizeMnistHeight SizeMnistHeight kh kw n n r
-> r
forall (kh :: Natural) (kw :: Natural) (h :: Natural)
       (w :: Natural) (c_out :: Natural) (n_hidden :: Natural)
       (batch_size :: Natural) (target :: TK -> Type) r.
((h :: Natural) ~ (SizeMnistHeight :: Natural),
 (w :: Natural) ~ (SizeMnistHeight :: Natural), (<=) @Natural 1 kh,
 (<=) @Natural 1 kw,
 (target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
 GoodScalar r, Differentiable r) =>
SNat kh
-> SNat kw
-> SNat c_out
-> SNat n_hidden
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
-> r
MnistCnnShaped2.convMnistTestS SNat kh
kh SNat kw
kw (forall (n :: Natural). KnownNat n => SNat n
SNat @c_out) (forall (n :: Natural). KnownNat n => SNat n
SNat @n_hidden)
          (forall (n :: Natural). KnownNat n => SNat n
SNat @batch_size) MnistDataBatchS batch_size r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete
  (X ((Concrete
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  1
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (Concrete
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  n
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (Concrete
         (TKS2
            ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
            (TKScalar r)),
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (Concrete
         (TKS2
            ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
            (TKScalar r)),
       Concrete
         (TKS2
            ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
pars)
  in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
      Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
        String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
               String
prefix Int
epochs Int
maxBatches
      trainData <- (MnistData r -> MnistDataS r) -> [MnistData r] -> [MnistDataS r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataS r
forall r. PrimElt r => MnistData r -> MnistDataS r
mkMnistDataS
                   ([MnistData r] -> [MnistDataS r])
-> IO [MnistData r] -> IO [MnistDataS r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
      testData <- map mkMnistDataS . take (totalBatchSize * maxBatches)
                  <$> loadMnistData testGlyphsPath testLabelsPath
      withSNat (totalBatchSize * maxBatches) $ \(SNat @lenTestData) -> do
       let testDataS :: MnistDataBatchS n r
testDataS = [MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
testData
           f :: MnistDataBatchS miniBatchSize r
             -> ADVal Concrete (XParams kh kw c_out n_hidden r)
             -> ADVal Concrete (TKScalar r)
           f :: MnistDataBatchS n r
-> ADVal
     Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> ADVal Concrete (TKScalar r)
f (Shaped
  ((':)
     @Natural
     n
     ((':)
        @Natural
        SizeMnistHeight
        ((':) @Natural SizeMnistHeight ('[] @Natural))))
  r
glyphR, Shaped
  ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
labelR) ADVal
  Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
adinputs =
             SNat kh
-> SNat kw
-> SNat n
-> SNat n
-> SNat n
-> (PrimalOf
      (ADVal Concrete)
      (TKS
         ((':)
            @Natural
            n
            ((':)
               @Natural
               SizeMnistHeight
               ((':) @Natural SizeMnistHeight ('[] @Natural))))
         r),
    PrimalOf
      (ADVal Concrete)
      (TKS
         ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r))
-> ADCnnMnistParametersShaped
     (ADVal Concrete) SizeMnistHeight SizeMnistHeight kh kw n n r
-> ADVal Concrete (TKScalar r)
forall (kh :: Natural) (kw :: Natural) (h :: Natural)
       (w :: Natural) (c_out :: Natural) (n_hidden :: Natural)
       (batch_size :: Natural) (target :: TK -> Type) r.
((h :: Natural) ~ (SizeMnistHeight :: Natural),
 (w :: Natural) ~ (SizeMnistHeight :: Natural), (<=) @Natural 1 kh,
 (<=) @Natural 1 kw, ADReady target, ADReady (PrimalOf target),
 GoodScalar r, Differentiable r) =>
SNat kh
-> SNat kw
-> SNat c_out
-> SNat n_hidden
-> SNat batch_size
-> (PrimalOf
      target
      (TKS
         ((':)
            @Natural
            batch_size
            ((':) @Natural h ((':) @Natural w ('[] @Natural))))
         r),
    PrimalOf
      target
      (TKS
         ((':)
            @Natural batch_size ((':) @Natural SizeMnistLabel ('[] @Natural)))
         r))
-> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
-> target (TKScalar r)
MnistCnnShaped2.convMnistLossFusedS
               SNat kh
kh SNat kw
kw (forall (n :: Natural). KnownNat n => SNat n
SNat @c_out) (forall (n :: Natural). KnownNat n => SNat n
SNat @n_hidden)
               SNat n
miniBatchSize (Shaped
  ((':)
     @Natural
     n
     ((':)
        @Natural
        SizeMnistHeight
        ((':) @Natural SizeMnistHeight ('[] @Natural))))
  r
-> Concrete
     (TKS
        ((':)
           @Natural
           n
           ((':)
              @Natural
              SizeMnistHeight
              ((':) @Natural SizeMnistHeight ('[] @Natural))))
        r)
forall r (target :: TK -> Type) (sh :: [Natural]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
  ((':)
     @Natural
     n
     ((':)
        @Natural
        SizeMnistHeight
        ((':) @Natural SizeMnistHeight ('[] @Natural))))
  r
glyphR, Shaped
  ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
-> Concrete
     (TKS
        ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)
forall r (target :: TK -> Type) (sh :: [Natural]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
  ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
labelR)
               (ADVal
  Concrete
  (X ((ADVal
         Concrete
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  1
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       ADVal
         Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (ADVal
         Concrete
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  n
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       ADVal
         Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (ADVal
         Concrete
         (TKS2
            ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
            (TKScalar r)),
       ADVal
         Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (ADVal
         Concrete
         (TKS2
            ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
            (TKScalar r)),
       ADVal
         Concrete
         (TKS2
            ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> ((ADVal
       Concrete
       (TKS2
          ((':)
             @Natural
             n
             ((':)
                @Natural
                1
                ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
          (TKScalar r)),
     ADVal
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
    (ADVal
       Concrete
       (TKS2
          ((':)
             @Natural
             n
             ((':)
                @Natural
                n
                ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
          (TKScalar r)),
     ADVal
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
    (ADVal
       Concrete
       (TKS2
          ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
          (TKScalar r)),
     ADVal
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
    (ADVal
       Concrete
       (TKS2
          ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
          (TKScalar r)),
     ADVal
       Concrete
       (TKS2 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))
forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget ADVal
  Concrete
  (X ((ADVal
         Concrete
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  1
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       ADVal
         Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (ADVal
         Concrete
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  n
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       ADVal
         Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (ADVal
         Concrete
         (TKS2
            ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
            (TKScalar r)),
       ADVal
         Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (ADVal
         Concrete
         (TKS2
            ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
            (TKScalar r)),
       ADVal
         Concrete
         (TKS2
            ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
ADVal
  Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
adinputs)
           runBatch :: (Concrete (XParams kh kw c_out n_hidden r), StateAdam (XParams kh kw c_out n_hidden r))
                    -> (Int, [MnistDataS r])
                    -> IO (Concrete (XParams kh kw c_out n_hidden r), StateAdam (XParams kh kw c_out n_hidden r))
           runBatch :: (Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
 StateAdam
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> (Int, [MnistDataS r])
-> IO
     (Concrete
        (X (ADCnnMnistParametersShaped
              Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
      StateAdam
        (X (ADCnnMnistParametersShaped
              Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
runBatch (!Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters, !StateAdam
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
stateAdam) (Int
k, [MnistDataS r]
chunk) = do
             let chunkS :: [MnistDataBatchS n r]
chunkS = ([MnistDataS r] -> MnistDataBatchS n r)
-> [[MnistDataS r]] -> [MnistDataBatchS n r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS
                          ([[MnistDataS r]] -> [MnistDataBatchS n r])
-> [[MnistDataS r]] -> [MnistDataBatchS n r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataS r] -> Bool) -> [[MnistDataS r]] -> [[MnistDataS r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataS r]
ch -> [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSizeInt)
                          ([[MnistDataS r]] -> [[MnistDataS r]])
-> [[MnistDataS r]] -> [[MnistDataS r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSizeInt [MnistDataS r]
chunk
                 res :: (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
res@(Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
parameters2, StateAdam
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
_) =
                   (MnistDataBatchS n r
 -> ADVal
      Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        1
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        n
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                  (TKScalar r))
               (TKS2
                  ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
 -> ADVal Concrete (TKScalar r))
-> [MnistDataBatchS n r]
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> StateAdam
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        1
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        n
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                  (TKScalar r))
               (TKS2
                  ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
    StateAdam
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        1
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        n
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                  (TKScalar r))
               (TKS2
                  ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall a (x :: TK) (z :: TK).
KnownSTK x =>
(a -> ADVal Concrete x -> ADVal Concrete z)
-> [a] -> Concrete x -> StateAdam x -> (Concrete x, StateAdam x)
sgdAdam MnistDataBatchS n r
-> ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> ADVal Concrete (TKScalar r)
MnistDataBatchS n r
-> ADVal
     Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> ADVal Concrete (TKScalar r)
f [MnistDataBatchS n r]
chunkS Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters StateAdam
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
StateAdam
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
stateAdam
                 trainScore :: r
trainScore = Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat ([MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk) ((forall (n :: Natural). KnownNat n => SNat n -> r) -> r)
-> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
forall a b. (a -> b) -> a -> b
$ \(SNat @len) ->
                   forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest @len ([MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
chunk) Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters2
                 testScore :: r
testScore = forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest @lenTestData MnistDataBatchS n r
testDataS Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters2
                 lenChunk :: Int
lenChunk = [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
                        String
prefix Int
k Int
lenChunk
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error:   %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
             (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
res
       let runEpoch :: Int
                    -> (Concrete (XParams kh kw c_out n_hidden r), StateAdam (XParams kh kw c_out n_hidden r))
                    -> IO (Concrete (XParams kh kw c_out n_hidden r))
           runEpoch :: Int
-> (Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
    StateAdam
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> IO
     (Concrete
        (X (ADCnnMnistParametersShaped
              Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
runEpoch Int
n (Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
params2, StateAdam
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
params2
           runEpoch Int
n paramsStateAdam :: (Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
 StateAdam
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
paramsStateAdam@(!Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
_, !StateAdam
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
_) = do
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
             let trainDataShuffled :: [MnistDataS r]
trainDataShuffled = StdGen -> [MnistDataS r] -> [MnistDataS r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataS r]
trainData
                 chunks :: [(Int, [MnistDataS r])]
chunks = Int -> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                          ([(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])])
-> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..]
                          ([[MnistDataS r]] -> [(Int, [MnistDataS r])])
-> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataS r]
trainDataShuffled
             res <- ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKS2
                ((':)
                   @Natural
                   n
                   ((':)
                      @Natural
                      1
                      ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
          (TKProduct
             (TKS2
                ((':)
                   @Natural
                   n
                   ((':)
                      @Natural
                      n
                      ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
       (TKProduct
          (TKProduct
             (TKS2
                ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
          (TKProduct
             (TKS2
                ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                (TKScalar r))
             (TKS2
                ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
  StateAdam
    (TKProduct
       (TKProduct
          (TKProduct
             (TKS2
                ((':)
                   @Natural
                   n
                   ((':)
                      @Natural
                      1
                      ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
          (TKProduct
             (TKS2
                ((':)
                   @Natural
                   n
                   ((':)
                      @Natural
                      n
                      ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
       (TKProduct
          (TKProduct
             (TKS2
                ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
          (TKProduct
             (TKS2
                ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                (TKScalar r))
             (TKS2
                ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
 -> (Int, [MnistDataS r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':)
                        @Natural
                        n
                        ((':)
                           @Natural
                           1
                           ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':)
                        @Natural
                        n
                        ((':)
                           @Natural
                           n
                           ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                     (TKScalar r))
                  (TKS2
                     ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
       StateAdam
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':)
                        @Natural
                        n
                        ((':)
                           @Natural
                           1
                           ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':)
                        @Natural
                        n
                        ((':)
                           @Natural
                           n
                           ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                     (TKScalar r))
                  (TKS2
                     ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        1
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        n
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                  (TKScalar r))
               (TKS2
                  ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
    StateAdam
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        1
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        n
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                  (TKScalar r))
               (TKS2
                  ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> [(Int, [MnistDataS r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> (Int, [MnistDataS r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
(Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
 StateAdam
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> (Int, [MnistDataS r])
-> IO
     (Concrete
        (X (ADCnnMnistParametersShaped
              Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
      StateAdam
        (X (ADCnnMnistParametersShaped
              Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
runBatch (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
(Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
 StateAdam
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
paramsStateAdam [(Int, [MnistDataS r])]
chunks
             runEpoch (succ n) res
           ftk :: FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams kh kw c_out n_hidden r)) Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit
       res <- Int
-> (Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
    StateAdam
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> IO
     (Concrete
        (X (ADCnnMnistParametersShaped
              Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
runEpoch Int
1 (Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit, FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> StateAdam
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
forall (y :: TK). FullShapeTK y -> StateAdam y
initialStateAdam FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
ftk)
       let testErrorFinal =
             r
1 r -> r -> r
forall a. Num a => a -> a -> a
- MnistDataBatchS n r
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest MnistDataBatchS n r
testDataS Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
res
       testErrorFinal @?~ expected

-- | Test group for the ADVal (non-vectorized, POPL-style) CNN pipeline.
--
-- Positional arguments to 'mnistTestCaseCNNSA', in order:
-- name prefix, epochs, maxBatches, kernel height @kh@, kernel width @kw@,
-- c_out, n_hidden, mini-batch size, total batch size, and the expected
-- final test error. The harness computes that error as
-- @1 - ftest testData params@ and compares it with '@?~'.
tensorADValMnistTestsCNNSA :: TestTree
tensorADValMnistTestsCNNSA = testGroup "CNNS ADVal MNIST tests"
  [ mnistTestCaseCNNSA "CNNSA 1 epoch, 1 batch"
                       1 1 (SNat @4) (SNat @4) 8 16 1 1
                       (1 :: Double)
  , mnistTestCaseCNNSA "CNNSA artificial 1 2 3 4 5"
                       1 1 (SNat @2) (SNat @3) 4 5 1 10
                       (1 :: Float)
  , mnistTestCaseCNNSA "CNNSA artificial 5 4 3 2 1"
                       5 4 (SNat @3) (SNat @2) 1 1 1 1
                       (1 :: Double)
  , mnistTestCaseCNNSA "CNNSA 1 epoch, 0 batch"
                       1 0 (SNat @4) (SNat @4) 16 64 16 50
                       (1.0 :: Float)
  ]

-- POPL differentiation, with Ast term defined and vectorized only once,
-- but differentiated anew in each gradient descent iteration.
mnistTestCaseCNNSI
  :: forall kh kw r.
     ( 1 <= kh, 1 <= kw
     , Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r )
  => String
  -> Int -> Int -> SNat kh -> SNat kw -> Int -> Int -> Int -> Int -> r
  -> TestTree
mnistTestCaseCNNSI :: forall (kh :: Natural) (kw :: Natural) r.
((<=) @Natural 1 kh, (<=) @Natural 1 kw, Differentiable r,
 GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> SNat kh
-> SNat kw
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNSI String
prefix Int
epochs Int
maxBatches kh :: SNat kh
kh@SNat kh
SNat kw :: SNat kw
kw@SNat kw
SNat Int
c_outInt Int
n_hiddenInt
                   Int
miniBatchSizeInt Int
totalBatchSize r
expected =
  Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
c_outInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
 -> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_c_outSNat :: SNat c_out) ->
  Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
n_hiddenInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
 -> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_n_hiddenSNat :: SNat n_hidden) ->
  Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
miniBatchSizeInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
 -> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
miniBatchSize :: SNat miniBatchSize) ->
  let targetInit :: Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit =
        (Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
 StdGen)
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
forall a b. (a, b) -> a
fst ((Concrete
    (X (ADCnnMnistParametersShaped
          Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
  StdGen)
 -> Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> (Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
    StdGen)
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue
                @(Concrete (X (MnistCnnShaped2.ADCnnMnistParametersShaped
                                 Concrete SizeMnistHeight SizeMnistWidth
                                 kh kw c_out n_hidden r)))
                Double
0.4 (Int -> StdGen
mkStdGen Int
44)
      name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
             String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
                        , Int -> String
forall a. Show a => a -> String
show (SNat kh -> Int
forall (n :: Natural). SNat n -> Int
sNatValue SNat kh
kh), Int -> String
forall a. Show a => a -> String
show (SNat kw -> Int
forall (n :: Natural). SNat n -> Int
sNatValue SNat kw
kw)
                        , Int -> String
forall a. Show a => a -> String
show Int
c_outInt, Int -> String
forall a. Show a => a -> String
show Int
n_hiddenInt
                        , Int -> String
forall a. Show a => a -> String
show Int
miniBatchSizeInt
                        , Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK (SingletonTK
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
 -> Int)
-> SingletonTK
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams kh kw c_out n_hidden r)
                        , Int -> String
forall a. Show a => a -> String
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit) ]
      ftest :: KnownNat batch_size
            => MnistDataBatchS batch_size r
            -> Concrete (XParams kh kw c_out n_hidden r) -> r
      ftest :: forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest @batch_size MnistDataBatchS batch_size r
mnistData Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
pars =
        SNat kh
-> SNat kw
-> SNat n
-> SNat n
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADCnnMnistParametersShaped
     Concrete SizeMnistHeight SizeMnistHeight kh kw n n r
-> r
forall (kh :: Natural) (kw :: Natural) (h :: Natural)
       (w :: Natural) (c_out :: Natural) (n_hidden :: Natural)
       (batch_size :: Natural) (target :: TK -> Type) r.
((h :: Natural) ~ (SizeMnistHeight :: Natural),
 (w :: Natural) ~ (SizeMnistHeight :: Natural), (<=) @Natural 1 kh,
 (<=) @Natural 1 kw,
 (target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
 GoodScalar r, Differentiable r) =>
SNat kh
-> SNat kw
-> SNat c_out
-> SNat n_hidden
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
-> r
MnistCnnShaped2.convMnistTestS SNat kh
kh SNat kw
kw (forall (n :: Natural). KnownNat n => SNat n
SNat @c_out) (forall (n :: Natural). KnownNat n => SNat n
SNat @n_hidden)
          (forall (n :: Natural). KnownNat n => SNat n
SNat @batch_size) MnistDataBatchS batch_size r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete
  (X ((Concrete
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  1
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (Concrete
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  n
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (Concrete
         (TKS2
            ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
            (TKScalar r)),
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (Concrete
         (TKS2
            ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
            (TKScalar r)),
       Concrete
         (TKS2
            ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
pars)
  in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
      Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
        String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
               String
prefix Int
epochs Int
maxBatches
      trainData <- (MnistData r -> MnistDataS r) -> [MnistData r] -> [MnistDataS r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataS r
forall r. PrimElt r => MnistData r -> MnistDataS r
mkMnistDataS
                   ([MnistData r] -> [MnistDataS r])
-> IO [MnistData r] -> IO [MnistDataS r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
      testData <- map mkMnistDataS . take (totalBatchSize * maxBatches)
                  <$> loadMnistData testGlyphsPath testLabelsPath
      withSNat (totalBatchSize * maxBatches) $ \(SNat @lenTestData) -> do
       let testDataS :: MnistDataBatchS n r
testDataS = [MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
testData
           ftk :: FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams kh kw c_out n_hidden r)) Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit
       (_, _, var, varAst2) <- FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> IO
     (AstVarName
        PrimalSpan
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
      AstTensor
        AstMethodShare
        PrimalSpan
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
      AstVarName
        FullSpan
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
      AstTensor
        AstMethodLet
        FullSpan
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall (x :: TK).
FullShapeTK x
-> IO
     (AstVarName PrimalSpan x, AstTensor AstMethodShare PrimalSpan x,
      AstVarName FullSpan x, AstTensor AstMethodLet FullSpan x)
funToAstRevIO FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
ftk
       (varGlyph, astGlyph) <-
         funToAstIO (FTKS (miniBatchSize
                           :$$ sizeMnistHeight
                           :$$ sizeMnistWidth
                           :$$ ZSS) FTKScalar) id
       (varLabel, astLabel) <-
         funToAstIO (FTKS (miniBatchSize
                           :$$ sizeMnistLabel
                           :$$ ZSS) FTKScalar) id
       let ast :: AstTensor AstMethodLet FullSpan (TKScalar r)
           ast = AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (z :: TK) (s :: AstSpanType).
AstSpan s =>
AstTensor AstMethodLet s z -> AstTensor AstMethodLet s z
simplifyInline
                 (AstTensor AstMethodLet FullSpan (TKScalar r)
 -> AstTensor AstMethodLet FullSpan (TKScalar r))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall a b. (a -> b) -> a -> b
$ SNat kh
-> SNat kw
-> SNat n
-> SNat n
-> SNat n
-> (PrimalOf
      (AstTensor AstMethodLet FullSpan)
      (TKS2
         ((':)
            @Natural
            n
            ((':)
               @Natural
               SizeMnistHeight
               ((':) @Natural SizeMnistHeight ('[] @Natural))))
         (TKScalar r)),
    PrimalOf
      (AstTensor AstMethodLet FullSpan)
      (TKS2
         ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
         (TKScalar r)))
-> ADCnnMnistParametersShaped
     (AstTensor AstMethodLet FullSpan)
     SizeMnistHeight
     SizeMnistHeight
     kh
     kw
     n
     n
     r
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (kh :: Natural) (kw :: Natural) (h :: Natural)
       (w :: Natural) (c_out :: Natural) (n_hidden :: Natural)
       (batch_size :: Natural) (target :: TK -> Type) r.
((h :: Natural) ~ (SizeMnistHeight :: Natural),
 (w :: Natural) ~ (SizeMnistHeight :: Natural), (<=) @Natural 1 kh,
 (<=) @Natural 1 kw, ADReady target, ADReady (PrimalOf target),
 GoodScalar r, Differentiable r) =>
SNat kh
-> SNat kw
-> SNat c_out
-> SNat n_hidden
-> SNat batch_size
-> (PrimalOf
      target
      (TKS
         ((':)
            @Natural
            batch_size
            ((':) @Natural h ((':) @Natural w ('[] @Natural))))
         r),
    PrimalOf
      target
      (TKS
         ((':)
            @Natural batch_size ((':) @Natural SizeMnistLabel ('[] @Natural)))
         r))
-> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
-> target (TKScalar r)
MnistCnnShaped2.convMnistLossFusedS
                     SNat kh
kh SNat kw
kw (forall (n :: Natural). KnownNat n => SNat n
SNat @c_out) (forall (n :: Natural). KnownNat n => SNat n
SNat @n_hidden)
                     SNat n
miniBatchSize (AstTensor
  AstMethodLet
  PrimalSpan
  (TKS2
     ((':)
        @Natural
        n
        ((':)
           @Natural
           SizeMnistHeight
           ((':) @Natural SizeMnistHeight ('[] @Natural))))
     (TKScalar r))
PrimalOf
  (AstTensor AstMethodLet FullSpan)
  (TKS2
     ((':)
        @Natural
        n
        ((':)
           @Natural
           SizeMnistHeight
           ((':) @Natural SizeMnistHeight ('[] @Natural))))
     (TKScalar r))
astGlyph, AstTensor
  AstMethodLet
  PrimalSpan
  (TKS2
     ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
     (TKScalar r))
PrimalOf
  (AstTensor AstMethodLet FullSpan)
  (TKS2
     ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
     (TKScalar r))
astLabel)
                     (AstTensor
  AstMethodLet
  FullSpan
  (X ((AstTensor
         AstMethodLet
         FullSpan
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  1
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       AstTensor
         AstMethodLet
         FullSpan
         (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (AstTensor
         AstMethodLet
         FullSpan
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  n
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       AstTensor
         AstMethodLet
         FullSpan
         (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (AstTensor
         AstMethodLet
         FullSpan
         (TKS2
            ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
            (TKScalar r)),
       AstTensor
         AstMethodLet
         FullSpan
         (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (AstTensor
         AstMethodLet
         FullSpan
         (TKS2
            ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
            (TKScalar r)),
       AstTensor
         AstMethodLet
         FullSpan
         (TKS2
            ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> ((AstTensor
       AstMethodLet
       FullSpan
       (TKS2
          ((':)
             @Natural
             n
             ((':)
                @Natural
                1
                ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
          (TKScalar r)),
     AstTensor
       AstMethodLet
       FullSpan
       (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
    (AstTensor
       AstMethodLet
       FullSpan
       (TKS2
          ((':)
             @Natural
             n
             ((':)
                @Natural
                n
                ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
          (TKScalar r)),
     AstTensor
       AstMethodLet
       FullSpan
       (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
    (AstTensor
       AstMethodLet
       FullSpan
       (TKS2
          ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
          (TKScalar r)),
     AstTensor
       AstMethodLet
       FullSpan
       (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
    (AstTensor
       AstMethodLet
       FullSpan
       (TKS2
          ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
          (TKScalar r)),
     AstTensor
       AstMethodLet
       FullSpan
       (TKS2 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))
forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget AstTensor
  AstMethodLet
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
AstTensor
  AstMethodLet
  FullSpan
  (X ((AstTensor
         AstMethodLet
         FullSpan
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  1
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       AstTensor
         AstMethodLet
         FullSpan
         (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (AstTensor
         AstMethodLet
         FullSpan
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  n
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       AstTensor
         AstMethodLet
         FullSpan
         (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (AstTensor
         AstMethodLet
         FullSpan
         (TKS2
            ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
            (TKScalar r)),
       AstTensor
         AstMethodLet
         FullSpan
         (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (AstTensor
         AstMethodLet
         FullSpan
         (TKS2
            ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
            (TKScalar r)),
       AstTensor
         AstMethodLet
         FullSpan
         (TKS2
            ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
varAst2)
           f :: MnistDataBatchS miniBatchSize r
             -> ADVal Concrete (XParams kh kw c_out n_hidden r)
             -> ADVal Concrete (TKScalar r)
           f (Shaped
  ((':)
     @Natural
     n
     ((':)
        @Natural
        SizeMnistHeight
        ((':) @Natural SizeMnistHeight ('[] @Natural))))
  r
glyph, Shaped
  ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
label) ADVal
  Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
varInputs =
             let env :: AstEnv (ADVal Concrete)
env = AstVarName
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
var ADVal
  Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
ADVal
  Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
varInputs AstEnv (ADVal Concrete)
forall (target :: TK -> Type). AstEnv target
emptyEnv
                 envMnist :: AstEnv (ADVal Concrete)
envMnist = AstVarName
  PrimalSpan
  (TKS2
     ((':)
        @Natural
        n
        ((':)
           @Natural
           SizeMnistHeight
           ((':) @Natural SizeMnistHeight ('[] @Natural))))
     (TKScalar r))
-> ADVal
     Concrete
     (TKS2
        ((':)
           @Natural
           n
           ((':)
              @Natural
              SizeMnistHeight
              ((':) @Natural SizeMnistHeight ('[] @Natural))))
        (TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
  PrimalSpan
  (TKS2
     ((':)
        @Natural
        n
        ((':)
           @Natural
           SizeMnistHeight
           ((':) @Natural SizeMnistHeight ('[] @Natural))))
     (TKScalar r))
varGlyph (Shaped
  ((':)
     @Natural
     n
     ((':)
        @Natural
        SizeMnistHeight
        ((':) @Natural SizeMnistHeight ('[] @Natural))))
  r
-> ADVal
     Concrete
     (TKS2
        ((':)
           @Natural
           n
           ((':)
              @Natural
              SizeMnistHeight
              ((':) @Natural SizeMnistHeight ('[] @Natural))))
        (TKScalar r))
forall r (target :: TK -> Type) (sh :: [Natural]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
  ((':)
     @Natural
     n
     ((':)
        @Natural
        SizeMnistHeight
        ((':) @Natural SizeMnistHeight ('[] @Natural))))
  r
glyph)
                            (AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete))
-> AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete)
forall a b. (a -> b) -> a -> b
$ AstVarName
  PrimalSpan
  (TKS2
     ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
     (TKScalar r))
-> ADVal
     Concrete
     (TKS2
        ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
        (TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
  PrimalSpan
  (TKS2
     ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
     (TKScalar r))
varLabel (Shaped
  ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
-> ADVal
     Concrete
     (TKS2
        ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
        (TKScalar r))
forall r (target :: TK -> Type) (sh :: [Natural]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
  ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
label) AstEnv (ADVal Concrete)
env
             in AstEnv (ADVal Concrete)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> ADVal Concrete (TKScalar r)
forall (target :: TK -> Type) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull AstEnv (ADVal Concrete)
envMnist AstTensor AstMethodLet FullSpan (TKScalar r)
ast
           runBatch :: (Concrete (XParams kh kw c_out n_hidden r), StateAdam (XParams kh kw c_out n_hidden r))
                    -> (Int, [MnistDataS r])
                    -> IO (Concrete (XParams kh kw c_out n_hidden r), StateAdam (XParams kh kw c_out n_hidden r))
           runBatch (!Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters, !StateAdam
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
stateAdam) (Int
k, [MnistDataS r]
chunk) = do
             let chunkS :: [MnistDataBatchS n r]
chunkS = ([MnistDataS r] -> MnistDataBatchS n r)
-> [[MnistDataS r]] -> [MnistDataBatchS n r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS
                          ([[MnistDataS r]] -> [MnistDataBatchS n r])
-> [[MnistDataS r]] -> [MnistDataBatchS n r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataS r] -> Bool) -> [[MnistDataS r]] -> [[MnistDataS r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataS r]
ch -> [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSizeInt)
                          ([[MnistDataS r]] -> [[MnistDataS r]])
-> [[MnistDataS r]] -> [[MnistDataS r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSizeInt [MnistDataS r]
chunk
                 res :: (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
res@(Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
parameters2, StateAdam
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
_) =
                   (MnistDataBatchS n r
 -> ADVal
      Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        1
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        n
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                  (TKScalar r))
               (TKS2
                  ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
 -> ADVal Concrete (TKScalar r))
-> [MnistDataBatchS n r]
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> StateAdam
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        1
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        n
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                  (TKScalar r))
               (TKS2
                  ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
    StateAdam
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        1
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        n
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                  (TKScalar r))
               (TKS2
                  ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall a (x :: TK) (z :: TK).
KnownSTK x =>
(a -> ADVal Concrete x -> ADVal Concrete z)
-> [a] -> Concrete x -> StateAdam x -> (Concrete x, StateAdam x)
sgdAdam MnistDataBatchS n r
-> ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> ADVal Concrete (TKScalar r)
MnistDataBatchS n r
-> ADVal
     Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> ADVal Concrete (TKScalar r)
f [MnistDataBatchS n r]
chunkS Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters StateAdam
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
StateAdam
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
stateAdam
                 !trainScore :: r
trainScore = Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat ([MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk) ((forall (n :: Natural). KnownNat n => SNat n -> r) -> r)
-> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
forall a b. (a -> b) -> a -> b
$ \(SNat @len) ->
                   forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest @len ([MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
chunk) Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters2
                 !testScore :: r
testScore = forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest @lenTestData MnistDataBatchS n r
testDataS Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters2
                 !lenChunk :: Int
lenChunk = [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
                        String
prefix Int
k Int
lenChunk
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error:   %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
             (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
res
       let runEpoch :: Int
                    -> (Concrete (XParams kh kw c_out n_hidden r), StateAdam (XParams kh kw c_out n_hidden r))
                    -> IO (Concrete (XParams kh kw c_out n_hidden r))
           runEpoch Int
n (Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
params2, StateAdam
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
params2
           runEpoch Int
n paramsStateAdam :: (Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
 StateAdam
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
paramsStateAdam@(!Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
_, !StateAdam
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
_) = do
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
             let trainDataShuffled :: [MnistDataS r]
trainDataShuffled = StdGen -> [MnistDataS r] -> [MnistDataS r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataS r]
trainData
                 chunks :: [(Int, [MnistDataS r])]
chunks = Int -> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                          ([(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])])
-> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..]
                          ([[MnistDataS r]] -> [(Int, [MnistDataS r])])
-> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataS r]
trainDataShuffled
             res <- ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKS2
                ((':)
                   @Natural
                   n
                   ((':)
                      @Natural
                      1
                      ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
          (TKProduct
             (TKS2
                ((':)
                   @Natural
                   n
                   ((':)
                      @Natural
                      n
                      ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
       (TKProduct
          (TKProduct
             (TKS2
                ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
          (TKProduct
             (TKS2
                ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                (TKScalar r))
             (TKS2
                ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
  StateAdam
    (TKProduct
       (TKProduct
          (TKProduct
             (TKS2
                ((':)
                   @Natural
                   n
                   ((':)
                      @Natural
                      1
                      ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
          (TKProduct
             (TKS2
                ((':)
                   @Natural
                   n
                   ((':)
                      @Natural
                      n
                      ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
       (TKProduct
          (TKProduct
             (TKS2
                ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
          (TKProduct
             (TKS2
                ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                (TKScalar r))
             (TKS2
                ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
 -> (Int, [MnistDataS r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':)
                        @Natural
                        n
                        ((':)
                           @Natural
                           1
                           ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':)
                        @Natural
                        n
                        ((':)
                           @Natural
                           n
                           ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                     (TKScalar r))
                  (TKS2
                     ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
       StateAdam
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':)
                        @Natural
                        n
                        ((':)
                           @Natural
                           1
                           ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':)
                        @Natural
                        n
                        ((':)
                           @Natural
                           n
                           ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                     (TKScalar r))
                  (TKS2
                     ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        1
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        n
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                  (TKScalar r))
               (TKS2
                  ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
    StateAdam
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        1
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        n
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                  (TKScalar r))
               (TKS2
                  ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> [(Int, [MnistDataS r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> (Int, [MnistDataS r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
(Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
 StateAdam
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> (Int, [MnistDataS r])
-> IO
     (Concrete
        (X (ADCnnMnistParametersShaped
              Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
      StateAdam
        (X (ADCnnMnistParametersShaped
              Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
runBatch (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
(Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
 StateAdam
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
paramsStateAdam [(Int, [MnistDataS r])]
chunks
             runEpoch (succ n) res
       res <- runEpoch 1 (targetInit, initialStateAdam ftk)
       let testErrorFinal =
             r
1 r -> r -> r
forall a. Num a => a -> a -> a
- MnistDataBatchS n r
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest MnistDataBatchS n r
testDataS Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
res
       testErrorFinal @?~ expected

-- | The "CNNS Intermediate MNIST tests" group: four small
-- 'mnistTestCaseCNNSI' runs.
--
-- Each row passes, in order: test-name prefix, epochs, max batches per
-- epoch, kernel height and width (as 'SNat's), @c_out@, @n_hidden@,
-- mini-batch size, total batch size, and the expected final test error.
-- NOTE(review): this argument order is assumed to match
-- 'mnistTestCaseCNNSO', whose signature takes the same argument types;
-- confirm against the definition of 'mnistTestCaseCNNSI'.
tensorADValMnistTestsCNNSI :: TestTree
tensorADValMnistTestsCNNSI = testGroup "CNNS Intermediate MNIST tests"
  [ mnistTestCaseCNNSI "CNNSI 1 epoch, 1 batch"
                       1 1 (SNat @4) (SNat @4) 8 16 1 1
                       (1 :: Double)
  , mnistTestCaseCNNSI "CNNSI artificial 1 2 3 4 5"
                       1 1 (SNat @2) (SNat @3) 4 5 1 10
                       (1 :: Float)
  , mnistTestCaseCNNSI "CNNSI artificial 5 4 3 2 1"
                       5 4 (SNat @3) (SNat @2) 1 1 1 1
                       (1 :: Double)
    -- Zero batches per epoch: no training step runs.
  , mnistTestCaseCNNSI "CNNSI 1 epoch, 0 batch"
                       1 0 (SNat @4) (SNat @4) 16 64 16 50
                       (1.0 :: Float)
  ]

-- JAX differentiation, Ast term built and differentiated only once
-- and the result interpreted with different inputs in each gradient
-- descent iteration.
mnistTestCaseCNNSO
  :: forall kh kw r.
     ( 1 <= kh, 1 <= kw
     , Differentiable r, GoodScalar r
     , PrintfArg r, AssertEqualUpToEpsilon r, ADTensorScalar r ~ r )
  => String
  -> Int -> Int -> SNat kh -> SNat kw -> Int -> Int -> Int -> Int -> r
  -> TestTree
mnistTestCaseCNNSO :: forall (kh :: Natural) (kw :: Natural) r.
((<=) @Natural 1 kh, (<=) @Natural 1 kw, Differentiable r,
 GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r,
 (ADTensorScalar r :: Type) ~ (r :: Type)) =>
String
-> Int
-> Int
-> SNat kh
-> SNat kw
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNSO String
prefix Int
epochs Int
maxBatches kh :: SNat kh
kh@SNat kh
SNat kw :: SNat kw
kw@SNat kw
SNat Int
c_outInt Int
n_hiddenInt
                   Int
miniBatchSizeInt Int
totalBatchSize r
expected =
  Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
c_outInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
 -> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_c_outSNat :: SNat c_out) ->
  Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
n_hiddenInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
 -> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_n_hiddenSNat :: SNat n_hidden) ->
  Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
miniBatchSizeInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
 -> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
miniBatchSize :: SNat miniBatchSize) ->
  let targetInit :: Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit =
        (Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
 StdGen)
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
forall a b. (a, b) -> a
fst ((Concrete
    (X (ADCnnMnistParametersShaped
          Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
  StdGen)
 -> Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> (Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
    StdGen)
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue
                @(Concrete (X (MnistCnnShaped2.ADCnnMnistParametersShaped
                                 Concrete SizeMnistHeight SizeMnistWidth
                                 kh kw c_out n_hidden r)))
                Double
0.4 (Int -> StdGen
mkStdGen Int
44)
      name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
             String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
                        , Int -> String
forall a. Show a => a -> String
show (SNat kh -> Int
forall (n :: Natural). SNat n -> Int
sNatValue SNat kh
kh), Int -> String
forall a. Show a => a -> String
show (SNat kw -> Int
forall (n :: Natural). SNat n -> Int
sNatValue SNat kw
kw)
                        , Int -> String
forall a. Show a => a -> String
show Int
c_outInt, Int -> String
forall a. Show a => a -> String
show Int
n_hiddenInt
                        , Int -> String
forall a. Show a => a -> String
show Int
miniBatchSizeInt
                        , Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
                          (SingletonTK
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
 -> Int)
-> SingletonTK
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams kh kw c_out n_hidden r)
                        , Int -> String
forall a. Show a => a -> String
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit) ]
      ftest :: KnownNat batch_size
            => MnistDataBatchS batch_size r
            -> Concrete (XParams kh kw c_out n_hidden r) -> r
      ftest :: forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest @batch_size MnistDataBatchS batch_size r
mnistData Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
pars =
        SNat kh
-> SNat kw
-> SNat n
-> SNat n
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADCnnMnistParametersShaped
     Concrete SizeMnistHeight SizeMnistHeight kh kw n n r
-> r
forall (kh :: Natural) (kw :: Natural) (h :: Natural)
       (w :: Natural) (c_out :: Natural) (n_hidden :: Natural)
       (batch_size :: Natural) (target :: TK -> Type) r.
((h :: Natural) ~ (SizeMnistHeight :: Natural),
 (w :: Natural) ~ (SizeMnistHeight :: Natural), (<=) @Natural 1 kh,
 (<=) @Natural 1 kw,
 (target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
 GoodScalar r, Differentiable r) =>
SNat kh
-> SNat kw
-> SNat c_out
-> SNat n_hidden
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
-> r
MnistCnnShaped2.convMnistTestS SNat kh
kh SNat kw
kw (forall (n :: Natural). KnownNat n => SNat n
SNat @c_out) (forall (n :: Natural). KnownNat n => SNat n
SNat @n_hidden)
          (forall (n :: Natural). KnownNat n => SNat n
SNat @batch_size) MnistDataBatchS batch_size r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete
  (X ((Concrete
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  1
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (Concrete
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  n
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (Concrete
         (TKS2
            ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
            (TKScalar r)),
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (Concrete
         (TKS2
            ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
            (TKScalar r)),
       Concrete
         (TKS2
            ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
pars)
  in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
      Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
        String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
               String
prefix Int
epochs Int
maxBatches
      trainData <- (MnistData r -> MnistDataS r) -> [MnistData r] -> [MnistDataS r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataS r
forall r. PrimElt r => MnistData r -> MnistDataS r
mkMnistDataS
                   ([MnistData r] -> [MnistDataS r])
-> IO [MnistData r] -> IO [MnistDataS r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
      testData <- map mkMnistDataS . take (totalBatchSize * maxBatches)
                  <$> loadMnistData testGlyphsPath testLabelsPath
      withSNat (totalBatchSize * maxBatches) $ \(SNat @lenTestData) -> do
       let testDataS :: MnistDataBatchS n r
testDataS = [MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
testData
           dataInit :: (Concrete
   (TKS
      ((':)
         @Natural
         n
         ((':)
            @Natural
            SizeMnistHeight
            ((':) @Natural SizeMnistHeight ('[] @Natural))))
      r),
 Concrete
   (TKS
      ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r))
dataInit = case Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSizeInt [MnistDataS r]
testData of
             [MnistDataS r]
d : [[MnistDataS r]]
_ -> let (Shaped
  ((':)
     @Natural
     n
     ((':)
        @Natural
        SizeMnistHeight
        ((':) @Natural SizeMnistHeight ('[] @Natural))))
  r
dglyph, Shaped
  ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
dlabel) = [MnistDataS r]
-> (Shaped
      ((':)
         @Natural
         n
         ((':)
            @Natural
            SizeMnistHeight
            ((':) @Natural SizeMnistHeight ('[] @Natural))))
      r,
    Shaped
      ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
d
                      in (Shaped
  ((':)
     @Natural
     n
     ((':)
        @Natural
        SizeMnistHeight
        ((':) @Natural SizeMnistHeight ('[] @Natural))))
  r
-> Concrete
     (TKS
        ((':)
           @Natural
           n
           ((':)
              @Natural
              SizeMnistHeight
              ((':) @Natural SizeMnistHeight ('[] @Natural))))
        r)
forall r (target :: TK -> Type) (sh :: [Natural]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
  ((':)
     @Natural
     n
     ((':)
        @Natural
        SizeMnistHeight
        ((':) @Natural SizeMnistHeight ('[] @Natural))))
  r
dglyph, Shaped
  ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
-> Concrete
     (TKS
        ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)
forall r (target :: TK -> Type) (sh :: [Natural]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
  ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
dlabel)
             [] -> String
-> (Concrete
      (TKS
         ((':)
            @Natural
            n
            ((':)
               @Natural
               SizeMnistHeight
               ((':) @Natural SizeMnistHeight ('[] @Natural))))
         r),
    Concrete
      (TKS
         ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r))
forall a. HasCallStack => String -> a
error String
"empty test data"
           f :: ( MnistCnnShaped2.ADCnnMnistParametersShaped
                    (AstTensor AstMethodLet FullSpan)
                    SizeMnistHeight SizeMnistWidth
                    kh kw c_out n_hidden r
                , ( AstTensor AstMethodLet FullSpan (TKS '[miniBatchSize, SizeMnistHeight, SizeMnistWidth] r)
                  , AstTensor AstMethodLet FullSpan (TKS '[miniBatchSize, SizeMnistLabel] r) ) )
             -> AstTensor AstMethodLet FullSpan (TKScalar r)
           f :: (ADCnnMnistParametersShaped
   (AstTensor AstMethodLet FullSpan)
   SizeMnistHeight
   SizeMnistHeight
   kh
   kw
   n
   n
   r,
 (AstTensor
    AstMethodLet
    FullSpan
    (TKS
       ((':)
          @Natural
          n
          ((':)
             @Natural
             SizeMnistHeight
             ((':) @Natural SizeMnistHeight ('[] @Natural))))
       r),
  AstTensor
    AstMethodLet
    FullSpan
    (TKS
       ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
       r)))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f = \ (ADCnnMnistParametersShaped
  (AstTensor AstMethodLet FullSpan)
  SizeMnistHeight
  SizeMnistHeight
  kh
  kw
  n
  n
  r
pars, (AstTensor
  AstMethodLet
  FullSpan
  (TKS
     ((':)
        @Natural
        n
        ((':)
           @Natural
           SizeMnistHeight
           ((':) @Natural SizeMnistHeight ('[] @Natural))))
     r)
glyphR, AstTensor
  AstMethodLet
  FullSpan
  (TKS
     ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)
labelR)) ->
             SNat kh
-> SNat kw
-> SNat n
-> SNat n
-> SNat n
-> (PrimalOf
      (AstTensor AstMethodLet FullSpan)
      (TKS
         ((':)
            @Natural
            n
            ((':)
               @Natural
               SizeMnistHeight
               ((':) @Natural SizeMnistHeight ('[] @Natural))))
         r),
    PrimalOf
      (AstTensor AstMethodLet FullSpan)
      (TKS
         ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r))
-> ADCnnMnistParametersShaped
     (AstTensor AstMethodLet FullSpan)
     SizeMnistHeight
     SizeMnistHeight
     kh
     kw
     n
     n
     r
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (kh :: Natural) (kw :: Natural) (h :: Natural)
       (w :: Natural) (c_out :: Natural) (n_hidden :: Natural)
       (batch_size :: Natural) (target :: TK -> Type) r.
((h :: Natural) ~ (SizeMnistHeight :: Natural),
 (w :: Natural) ~ (SizeMnistHeight :: Natural), (<=) @Natural 1 kh,
 (<=) @Natural 1 kw, ADReady target, ADReady (PrimalOf target),
 GoodScalar r, Differentiable r) =>
SNat kh
-> SNat kw
-> SNat c_out
-> SNat n_hidden
-> SNat batch_size
-> (PrimalOf
      target
      (TKS
         ((':)
            @Natural
            batch_size
            ((':) @Natural h ((':) @Natural w ('[] @Natural))))
         r),
    PrimalOf
      target
      (TKS
         ((':)
            @Natural batch_size ((':) @Natural SizeMnistLabel ('[] @Natural)))
         r))
-> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
-> target (TKScalar r)
MnistCnnShaped2.convMnistLossFusedS
               SNat kh
kh SNat kw
kw (forall (n :: Natural). KnownNat n => SNat n
SNat @c_out) (forall (n :: Natural). KnownNat n => SNat n
SNat @n_hidden)
               SNat n
miniBatchSize (AstTensor
  AstMethodLet
  FullSpan
  (TKS
     ((':)
        @Natural
        n
        ((':)
           @Natural
           SizeMnistHeight
           ((':) @Natural SizeMnistHeight ('[] @Natural))))
     r)
-> PrimalOf
     (AstTensor AstMethodLet FullSpan)
     (TKS
        ((':)
           @Natural
           n
           ((':)
              @Natural
              SizeMnistHeight
              ((':) @Natural SizeMnistHeight ('[] @Natural))))
        r)
forall (target :: TK -> Type) (sh :: [Natural]) (x :: TK).
BaseTensor target =>
target (TKS2 sh x) -> PrimalOf target (TKS2 sh x)
sprimalPart AstTensor
  AstMethodLet
  FullSpan
  (TKS
     ((':)
        @Natural
        n
        ((':)
           @Natural
           SizeMnistHeight
           ((':) @Natural SizeMnistHeight ('[] @Natural))))
     r)
glyphR, AstTensor
  AstMethodLet
  FullSpan
  (TKS
     ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)
-> PrimalOf
     (AstTensor AstMethodLet FullSpan)
     (TKS
        ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)
forall (target :: TK -> Type) (sh :: [Natural]) (x :: TK).
BaseTensor target =>
target (TKS2 sh x) -> PrimalOf target (TKS2 sh x)
sprimalPart AstTensor
  AstMethodLet
  FullSpan
  (TKS
     ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)
labelR) ADCnnMnistParametersShaped
  (AstTensor AstMethodLet FullSpan)
  SizeMnistHeight
  SizeMnistHeight
  kh
  kw
  n
  n
  r
pars
           artRaw :: AstArtifactRev
  (X (((AstTensor
          AstMethodLet
          FullSpan
          (TKS2
             ((':)
                @Natural
                n
                ((':)
                   @Natural
                   1
                   ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
             (TKScalar r)),
        AstTensor
          AstMethodLet
          FullSpan
          (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
       (AstTensor
          AstMethodLet
          FullSpan
          (TKS2
             ((':)
                @Natural
                n
                ((':)
                   @Natural
                   n
                   ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
             (TKScalar r)),
        AstTensor
          AstMethodLet
          FullSpan
          (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
       (AstTensor
          AstMethodLet
          FullSpan
          (TKS2
             ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
             (TKScalar r)),
        AstTensor
          AstMethodLet
          FullSpan
          (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
       (AstTensor
          AstMethodLet
          FullSpan
          (TKS2
             ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
             (TKScalar r)),
        AstTensor
          AstMethodLet
          FullSpan
          (TKS2
             ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))),
      (AstTensor
         AstMethodLet
         FullSpan
         (TKS
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  SizeMnistHeight
                  ((':) @Natural SizeMnistHeight ('[] @Natural))))
            r),
       AstTensor
         AstMethodLet
         FullSpan
         (TKS
            ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
            r))))
  (TKScalar r)
artRaw = ((((AstTensor
      AstMethodLet
      FullSpan
      (TKS2
         ((':)
            @Natural
            n
            ((':)
               @Natural
               1
               ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
         (TKScalar r)),
    AstTensor
      AstMethodLet
      FullSpan
      (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
   (AstTensor
      AstMethodLet
      FullSpan
      (TKS2
         ((':)
            @Natural
            n
            ((':)
               @Natural
               n
               ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
         (TKScalar r)),
    AstTensor
      AstMethodLet
      FullSpan
      (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
   (AstTensor
      AstMethodLet
      FullSpan
      (TKS2
         ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
         (TKScalar r)),
    AstTensor
      AstMethodLet
      FullSpan
      (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
   (AstTensor
      AstMethodLet
      FullSpan
      (TKS2
         ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
         (TKScalar r)),
    AstTensor
      AstMethodLet
      FullSpan
      (TKS2
         ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))),
  (AstTensor
     AstMethodLet
     FullSpan
     (TKS
        ((':)
           @Natural
           n
           ((':)
              @Natural
              SizeMnistHeight
              ((':) @Natural SizeMnistHeight ('[] @Natural))))
        r),
   AstTensor
     AstMethodLet
     FullSpan
     (TKS
        ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
        r)))
 -> AstTensor AstMethodLet FullSpan (TKScalar r))
-> Value
     (((AstTensor
          AstMethodLet
          FullSpan
          (TKS2
             ((':)
                @Natural
                n
                ((':)
                   @Natural
                   1
                   ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
             (TKScalar r)),
        AstTensor
          AstMethodLet
          FullSpan
          (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
       (AstTensor
          AstMethodLet
          FullSpan
          (TKS2
             ((':)
                @Natural
                n
                ((':)
                   @Natural
                   n
                   ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
             (TKScalar r)),
        AstTensor
          AstMethodLet
          FullSpan
          (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
       (AstTensor
          AstMethodLet
          FullSpan
          (TKS2
             ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
             (TKScalar r)),
        AstTensor
          AstMethodLet
          FullSpan
          (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
       (AstTensor
          AstMethodLet
          FullSpan
          (TKS2
             ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
             (TKScalar r)),
        AstTensor
          AstMethodLet
          FullSpan
          (TKS2
             ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))),
      (AstTensor
         AstMethodLet
         FullSpan
         (TKS
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  SizeMnistHeight
                  ((':) @Natural SizeMnistHeight ('[] @Natural))))
            r),
       AstTensor
         AstMethodLet
         FullSpan
         (TKS
            ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
            r)))
-> AstArtifactRev
     (X (((AstTensor
             AstMethodLet
             FullSpan
             (TKS2
                ((':)
                   @Natural
                   n
                   ((':)
                      @Natural
                      1
                      ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                (TKScalar r)),
           AstTensor
             AstMethodLet
             FullSpan
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
          (AstTensor
             AstMethodLet
             FullSpan
             (TKS2
                ((':)
                   @Natural
                   n
                   ((':)
                      @Natural
                      n
                      ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                (TKScalar r)),
           AstTensor
             AstMethodLet
             FullSpan
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
          (AstTensor
             AstMethodLet
             FullSpan
             (TKS2
                ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                (TKScalar r)),
           AstTensor
             AstMethodLet
             FullSpan
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
          (AstTensor
             AstMethodLet
             FullSpan
             (TKS2
                ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                (TKScalar r)),
           AstTensor
             AstMethodLet
             FullSpan
             (TKS2
                ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))),
         (AstTensor
            AstMethodLet
            FullSpan
            (TKS
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     SizeMnistHeight
                     ((':) @Natural SizeMnistHeight ('[] @Natural))))
               r),
          AstTensor
            AstMethodLet
            FullSpan
            (TKS
               ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
               r))))
     (TKScalar r)
forall src r tgt.
((X src :: TK) ~ (X (Value src) :: TK), KnownSTK (X src),
 AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
 AdaptableTarget Concrete (Value src),
 (tgt :: Type)
 ~ (AstTensor AstMethodLet FullSpan (TKScalar r) :: Type)) =>
(src -> tgt) -> Value src -> AstArtifactRev (X src) (TKScalar r)
gradArtifact (((AstTensor
     AstMethodLet
     FullSpan
     (TKS2
        ((':)
           @Natural
           n
           ((':)
              @Natural
              1
              ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
        (TKScalar r)),
   AstTensor
     AstMethodLet
     FullSpan
     (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
  (AstTensor
     AstMethodLet
     FullSpan
     (TKS2
        ((':)
           @Natural
           n
           ((':)
              @Natural
              n
              ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
        (TKScalar r)),
   AstTensor
     AstMethodLet
     FullSpan
     (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
  (AstTensor
     AstMethodLet
     FullSpan
     (TKS2
        ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
        (TKScalar r)),
   AstTensor
     AstMethodLet
     FullSpan
     (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
  (AstTensor
     AstMethodLet
     FullSpan
     (TKS2
        ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
        (TKScalar r)),
   AstTensor
     AstMethodLet
     FullSpan
     (TKS2
        ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))),
 (AstTensor
    AstMethodLet
    FullSpan
    (TKS
       ((':)
          @Natural
          n
          ((':)
             @Natural
             SizeMnistHeight
             ((':) @Natural SizeMnistHeight ('[] @Natural))))
       r),
  AstTensor
    AstMethodLet
    FullSpan
    (TKS
       ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
       r)))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
(ADCnnMnistParametersShaped
   (AstTensor AstMethodLet FullSpan)
   SizeMnistHeight
   SizeMnistHeight
   kh
   kw
   n
   n
   r,
 (AstTensor
    AstMethodLet
    FullSpan
    (TKS
       ((':)
          @Natural
          n
          ((':)
             @Natural
             SizeMnistHeight
             ((':) @Natural SizeMnistHeight ('[] @Natural))))
       r),
  AstTensor
    AstMethodLet
    FullSpan
    (TKS
       ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
       r)))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f (Concrete
  (X ((Concrete
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  1
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (Concrete
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  n
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (Concrete
         (TKS2
            ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
            (TKScalar r)),
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (Concrete
         (TKS2
            ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
            (TKScalar r)),
       Concrete
         (TKS2
            ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> ((Concrete
       (TKS2
          ((':)
             @Natural
             n
             ((':)
                @Natural
                1
                ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
          (TKScalar r)),
     Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
    (Concrete
       (TKS2
          ((':)
             @Natural
             n
             ((':)
                @Natural
                n
                ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
          (TKScalar r)),
     Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
    (Concrete
       (TKS2
          ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
          (TKScalar r)),
     Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
    (Concrete
       (TKS2
          ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
          (TKScalar r)),
     Concrete
       (TKS2 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))
forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
  (X ((Concrete
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  1
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (Concrete
         (TKS2
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  n
                  ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
            (TKScalar r)),
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (Concrete
         (TKS2
            ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
            (TKScalar r)),
       Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
      (Concrete
         (TKS2
            ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
            (TKScalar r)),
       Concrete
         (TKS2
            ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit, (Concrete
   (TKS
      ((':)
         @Natural
         n
         ((':)
            @Natural
            SizeMnistHeight
            ((':) @Natural SizeMnistHeight ('[] @Natural))))
      r),
 Concrete
   (TKS
      ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r))
dataInit)
           art :: AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
     (TKProduct
        (TKS
           ((':)
              @Natural
              n
              ((':)
                 @Natural
                 SizeMnistHeight
                 ((':) @Natural SizeMnistHeight ('[] @Natural))))
           r)
        (TKS
           ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
           r)))
  (TKScalar r)
art = AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
     (TKProduct
        (TKS
           ((':)
              @Natural
              n
              ((':)
                 @Natural
                 SizeMnistHeight
                 ((':) @Natural SizeMnistHeight ('[] @Natural))))
           r)
        (TKS
           ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
           r)))
  (TKScalar r)
-> AstArtifactRev
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
        (TKProduct
           (TKS
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    SizeMnistHeight
                    ((':) @Natural SizeMnistHeight ('[] @Natural))))
              r)
           (TKS
              ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
              r)))
     (TKScalar r)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifactGradient AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
     (TKProduct
        (TKS
           ((':)
              @Natural
              n
              ((':)
                 @Natural
                 SizeMnistHeight
                 ((':) @Natural SizeMnistHeight ('[] @Natural))))
           r)
        (TKS
           ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
           r)))
  (TKScalar r)
AstArtifactRev
  (X (((AstTensor
          AstMethodLet
          FullSpan
          (TKS2
             ((':)
                @Natural
                n
                ((':)
                   @Natural
                   1
                   ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
             (TKScalar r)),
        AstTensor
          AstMethodLet
          FullSpan
          (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
       (AstTensor
          AstMethodLet
          FullSpan
          (TKS2
             ((':)
                @Natural
                n
                ((':)
                   @Natural
                   n
                   ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
             (TKScalar r)),
        AstTensor
          AstMethodLet
          FullSpan
          (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
       (AstTensor
          AstMethodLet
          FullSpan
          (TKS2
             ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
             (TKScalar r)),
        AstTensor
          AstMethodLet
          FullSpan
          (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
       (AstTensor
          AstMethodLet
          FullSpan
          (TKS2
             ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
             (TKScalar r)),
        AstTensor
          AstMethodLet
          FullSpan
          (TKS2
             ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))),
      (AstTensor
         AstMethodLet
         FullSpan
         (TKS
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  SizeMnistHeight
                  ((':) @Natural SizeMnistHeight ('[] @Natural))))
            r),
       AstTensor
         AstMethodLet
         FullSpan
         (TKS
            ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
            r))))
  (TKScalar r)
artRaw
           go :: [MnistDataBatchS miniBatchSize r]
              -> ( Concrete (XParams kh kw c_out n_hidden r)
                 , StateAdam (XParams kh kw c_out n_hidden r) )
              -> ( Concrete (XParams kh kw c_out n_hidden r)
                 , StateAdam (XParams kh kw c_out n_hidden r) )
           go :: [(Shaped
    ((':)
       @Natural
       n
       ((':)
          @Natural
          SizeMnistHeight
          ((':) @Natural SizeMnistHeight ('[] @Natural))))
    r,
  Shaped
    ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)]
-> (Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
    StateAdam
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> (Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
    StateAdam
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
go [] (Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters, StateAdam
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
stateAdam) = (Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters, StateAdam
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
stateAdam)
           go ((Shaped
  ((':)
     @Natural
     n
     ((':)
        @Natural
        SizeMnistHeight
        ((':) @Natural SizeMnistHeight ('[] @Natural))))
  r
glyph, Shaped
  ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
label) : [(Shaped
    ((':)
       @Natural
       n
       ((':)
          @Natural
          SizeMnistHeight
          ((':) @Natural SizeMnistHeight ('[] @Natural))))
    r,
  Shaped
    ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)]
rest) (!Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters, !StateAdam
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
stateAdam) =
             let parametersAndInput :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
     (TKProduct
        (TKS
           ((':)
              @Natural
              n
              ((':)
                 @Natural
                 SizeMnistHeight
                 ((':) @Natural SizeMnistHeight ('[] @Natural))))
           r)
        (TKS
           ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
           r)))
parametersAndInput =
                   Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> Concrete
     (TKProduct
        (TKS
           ((':)
              @Natural
              n
              ((':)
                 @Natural
                 SizeMnistHeight
                 ((':) @Natural SizeMnistHeight ('[] @Natural))))
           r)
        (TKS
           ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
        (TKProduct
           (TKS
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    SizeMnistHeight
                    ((':) @Natural SizeMnistHeight ('[] @Natural))))
              r)
           (TKS
              ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
              r)))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters (Concrete
  (TKS
     ((':)
        @Natural
        n
        ((':)
           @Natural
           SizeMnistHeight
           ((':) @Natural SizeMnistHeight ('[] @Natural))))
     r)
-> Concrete
     (TKS
        ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)
-> Concrete
     (TKProduct
        (TKS
           ((':)
              @Natural
              n
              ((':)
                 @Natural
                 SizeMnistHeight
                 ((':) @Natural SizeMnistHeight ('[] @Natural))))
           r)
        (TKS
           ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (Shaped
  ((':)
     @Natural
     n
     ((':)
        @Natural
        SizeMnistHeight
        ((':) @Natural SizeMnistHeight ('[] @Natural))))
  r
-> Concrete
     (TKS
        ((':)
           @Natural
           n
           ((':)
              @Natural
              SizeMnistHeight
              ((':) @Natural SizeMnistHeight ('[] @Natural))))
        r)
forall r (target :: TK -> Type) (sh :: [Natural]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
  ((':)
     @Natural
     n
     ((':)
        @Natural
        SizeMnistHeight
        ((':) @Natural SizeMnistHeight ('[] @Natural))))
  r
glyph) (Shaped
  ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
-> Concrete
     (TKS
        ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)
forall r (target :: TK -> Type) (sh :: [Natural]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
  ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
label))
                 gradient :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
gradient =
                   Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
     (TKProduct
        (TKS
           ((':)
              @Natural
              n
              ((':)
                 @Natural
                 SizeMnistHeight
                 ((':) @Natural SizeMnistHeight ('[] @Natural))))
           r)
        (TKS
           ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
           r)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
forall (x :: TK) (z :: TK). Concrete (TKProduct x z) -> Concrete x
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        1
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        n
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                  (TKScalar r))
               (TKS2
                  ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
      (TKProduct
         (TKS
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  SizeMnistHeight
                  ((':) @Natural SizeMnistHeight ('[] @Natural))))
            r)
         (TKS
            ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
            r)))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        1
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        n
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                  (TKScalar r))
               (TKS2
                  ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
        (TKProduct
           (TKS
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    SizeMnistHeight
                    ((':) @Natural SizeMnistHeight ('[] @Natural))))
              r)
           (TKS
              ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
              r)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
forall a b. (a -> b) -> a -> b
$ (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        1
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        n
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                  (TKScalar r))
               (TKS2
                  ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
      (TKProduct
         (TKS
            ((':)
               @Natural
               n
               ((':)
                  @Natural
                  SizeMnistHeight
                  ((':) @Natural SizeMnistHeight ('[] @Natural))))
            r)
         (TKS
            ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
            r))),
 Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
        (TKProduct
           (TKS
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    SizeMnistHeight
                    ((':) @Natural SizeMnistHeight ('[] @Natural))))
              r)
           (TKS
              ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
              r)))
forall a b. (a, b) -> a
fst
                   ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct
                (TKS2
                   ((':)
                      @Natural
                      n
                      ((':)
                         @Natural
                         1
                         ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                   (TKScalar r))
                (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
             (TKProduct
                (TKS2
                   ((':)
                      @Natural
                      n
                      ((':)
                         @Natural
                         n
                         ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                   (TKScalar r))
                (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
          (TKProduct
             (TKProduct
                (TKS2
                   ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                   (TKScalar r))
                (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
             (TKProduct
                (TKS2
                   ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                   (TKScalar r))
                (TKS2
                   ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
       (TKProduct
          (TKS
             ((':)
                @Natural
                n
                ((':)
                   @Natural
                   SizeMnistHeight
                   ((':) @Natural SizeMnistHeight ('[] @Natural))))
             r)
          (TKS
             ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
             r))),
  Concrete (TKScalar r))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':)
                        @Natural
                        n
                        ((':)
                           @Natural
                           1
                           ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':)
                        @Natural
                        n
                        ((':)
                           @Natural
                           n
                           ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                     (TKScalar r))
                  (TKS2
                     ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
         (TKProduct
            (TKS
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     SizeMnistHeight
                     ((':) @Natural SizeMnistHeight ('[] @Natural))))
               r)
            (TKS
               ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
               r))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':)
                        @Natural
                        n
                        ((':)
                           @Natural
                           1
                           ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':)
                        @Natural
                        n
                        ((':)
                           @Natural
                           n
                           ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                     (TKScalar r))
                  (TKS2
                     ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
         (TKProduct
            (TKS
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     SizeMnistHeight
                     ((':) @Natural SizeMnistHeight ('[] @Natural))))
               r)
            (TKS
               ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
               r))),
    Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
        (TKProduct
           (TKS
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    SizeMnistHeight
                    ((':) @Natural SizeMnistHeight ('[] @Natural))))
              r)
           (TKS
              ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
              r)))
forall a b. (a -> b) -> a -> b
$ AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
     (TKProduct
        (TKS
           ((':)
              @Natural
              n
              ((':)
                 @Natural
                 SizeMnistHeight
                 ((':) @Natural SizeMnistHeight ('[] @Natural))))
           r)
        (TKS
           ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
           r)))
  (TKScalar r)
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
        (TKProduct
           (TKS
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    SizeMnistHeight
                    ((':) @Natural SizeMnistHeight ('[] @Natural))))
              r)
           (TKS
              ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
              r)))
-> Maybe (Concrete (ADTensorKind (TKScalar r)))
-> (Concrete
      (ADTensorKind
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct
                     (TKS2
                        ((':)
                           @Natural
                           n
                           ((':)
                              @Natural
                              1
                              ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                        (TKScalar r))
                     (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
                  (TKProduct
                     (TKS2
                        ((':)
                           @Natural
                           n
                           ((':)
                              @Natural
                              n
                              ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                        (TKScalar r))
                     (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
               (TKProduct
                  (TKProduct
                     (TKS2
                        ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                        (TKScalar r))
                     (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
                  (TKProduct
                     (TKS2
                        ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                        (TKScalar r))
                     (TKS2
                        ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
            (TKProduct
               (TKS
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        SizeMnistHeight
                        ((':) @Natural SizeMnistHeight ('[] @Natural))))
                  r)
               (TKS
                  ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
                  r)))),
    Concrete (TKScalar r))
forall (x :: TK) (z :: TK).
AstArtifactRev x z
-> Concrete x
-> Maybe (Concrete (ADTensorKind z))
-> (Concrete (ADTensorKind x), Concrete z)
revInterpretArtifact AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
     (TKProduct
        (TKS
           ((':)
              @Natural
              n
              ((':)
                 @Natural
                 SizeMnistHeight
                 ((':) @Natural SizeMnistHeight ('[] @Natural))))
           r)
        (TKS
           ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
           r)))
  (TKScalar r)
art Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
     (TKProduct
        (TKS
           ((':)
              @Natural
              n
              ((':)
                 @Natural
                 SizeMnistHeight
                 ((':) @Natural SizeMnistHeight ('[] @Natural))))
           r)
        (TKS
           ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
           r)))
parametersAndInput Maybe (Concrete (ADTensorKind (TKScalar r)))
Maybe (Concrete (TKScalar r))
forall a. Maybe a
Nothing
             in [(Shaped
    ((':)
       @Natural
       n
       ((':)
          @Natural
          SizeMnistHeight
          ((':) @Natural SizeMnistHeight ('[] @Natural))))
    r,
  Shaped
    ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)]
-> (Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
    StateAdam
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> (Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
    StateAdam
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
go [(Shaped
    ((':)
       @Natural
       n
       ((':)
          @Natural
          SizeMnistHeight
          ((':) @Natural SizeMnistHeight ('[] @Natural))))
    r,
  Shaped
    ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)]
rest (forall (y :: TK).
ArgsAdam
-> StateAdam y
-> SingletonTK y
-> Concrete y
-> Concrete (ADTensorKind y)
-> (Concrete y, StateAdam y)
updateWithGradientAdam
                           @(XParams kh kw c_out n_hidden r)
                           ArgsAdam
defaultArgsAdam StateAdam
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
stateAdam SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
SingletonTK
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters
                           Concrete
  (ADTensorKind
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
gradient)
           runBatch :: ( Concrete (XParams kh kw c_out n_hidden r)
                       , StateAdam (XParams kh kw c_out n_hidden r) )
                    -> (Int, [MnistDataS r])
                    -> IO ( Concrete (XParams kh kw c_out n_hidden r)
                          , StateAdam (XParams kh kw c_out n_hidden r) )
           runBatch :: (Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
 StateAdam
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> (Int, [MnistDataS r])
-> IO
     (Concrete
        (X (ADCnnMnistParametersShaped
              Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
      StateAdam
        (X (ADCnnMnistParametersShaped
              Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
runBatch (!Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters, !StateAdam
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
stateAdam) (Int
k, [MnistDataS r]
chunk) = do
             let chunkS :: [(Shaped
    ((':)
       @Natural
       n
       ((':)
          @Natural
          SizeMnistHeight
          ((':) @Natural SizeMnistHeight ('[] @Natural))))
    r,
  Shaped
    ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)]
chunkS = ([MnistDataS r]
 -> (Shaped
       ((':)
          @Natural
          n
          ((':)
             @Natural
             SizeMnistHeight
             ((':) @Natural SizeMnistHeight ('[] @Natural))))
       r,
     Shaped
       ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r))
-> [[MnistDataS r]]
-> [(Shaped
       ((':)
          @Natural
          n
          ((':)
             @Natural
             SizeMnistHeight
             ((':) @Natural SizeMnistHeight ('[] @Natural))))
       r,
     Shaped
       ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataS r]
-> (Shaped
      ((':)
         @Natural
         n
         ((':)
            @Natural
            SizeMnistHeight
            ((':) @Natural SizeMnistHeight ('[] @Natural))))
      r,
    Shaped
      ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS
                          ([[MnistDataS r]]
 -> [(Shaped
        ((':)
           @Natural
           n
           ((':)
              @Natural
              SizeMnistHeight
              ((':) @Natural SizeMnistHeight ('[] @Natural))))
        r,
      Shaped
        ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
        r)])
-> [[MnistDataS r]]
-> [(Shaped
       ((':)
          @Natural
          n
          ((':)
             @Natural
             SizeMnistHeight
             ((':) @Natural SizeMnistHeight ('[] @Natural))))
       r,
     Shaped
       ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)]
forall a b. (a -> b) -> a -> b
$ ([MnistDataS r] -> Bool) -> [[MnistDataS r]] -> [[MnistDataS r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataS r]
ch -> [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSizeInt)
                          ([[MnistDataS r]] -> [[MnistDataS r]])
-> [[MnistDataS r]] -> [[MnistDataS r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSizeInt [MnistDataS r]
chunk
                 res :: (Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
 StateAdam
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
res@(Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters2, StateAdam
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
_) = [(Shaped
    ((':)
       @Natural
       n
       ((':)
          @Natural
          SizeMnistHeight
          ((':) @Natural SizeMnistHeight ('[] @Natural))))
    r,
  Shaped
    ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)]
-> (Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
    StateAdam
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> (Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
    StateAdam
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
go [(Shaped
    ((':)
       @Natural
       n
       ((':)
          @Natural
          SizeMnistHeight
          ((':) @Natural SizeMnistHeight ('[] @Natural))))
    r,
  Shaped
    ((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)]
chunkS (Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters, StateAdam
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
stateAdam)
                 trainScore :: r
trainScore = Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat ([MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk) ((forall (n :: Natural). KnownNat n => SNat n -> r) -> r)
-> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
forall a b. (a -> b) -> a -> b
$ \(SNat @len) ->
                   forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest @len ([MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
chunk) Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters2
                 testScore :: r
testScore = forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest @lenTestData MnistDataBatchS n r
testDataS Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters2
                 lenChunk :: Int
lenChunk = [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
                        String
prefix Int
k Int
lenChunk
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error:   %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
             (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
(Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
 StateAdam
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
res
       let runEpoch :: Int
                    -> (Concrete (XParams kh kw c_out n_hidden r), StateAdam (XParams kh kw c_out n_hidden r))
                    -> IO (Concrete (XParams kh kw c_out n_hidden r))
           runEpoch :: Int
-> (Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
    StateAdam
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> IO
     (Concrete
        (X (ADCnnMnistParametersShaped
              Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
runEpoch Int
n (Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
params2, StateAdam
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
params2
           runEpoch Int
n paramsStateAdam :: (Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
 StateAdam
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
paramsStateAdam@(!Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
_, !StateAdam
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
_) = do
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
             let trainDataShuffled :: [MnistDataS r]
trainDataShuffled = StdGen -> [MnistDataS r] -> [MnistDataS r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataS r]
trainData
                 chunks :: [(Int, [MnistDataS r])]
chunks = Int -> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                          ([(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])])
-> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..]
                          ([[MnistDataS r]] -> [(Int, [MnistDataS r])])
-> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataS r]
trainDataShuffled
             res <- ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKS2
                ((':)
                   @Natural
                   n
                   ((':)
                      @Natural
                      1
                      ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
          (TKProduct
             (TKS2
                ((':)
                   @Natural
                   n
                   ((':)
                      @Natural
                      n
                      ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
       (TKProduct
          (TKProduct
             (TKS2
                ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
          (TKProduct
             (TKS2
                ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                (TKScalar r))
             (TKS2
                ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
  StateAdam
    (TKProduct
       (TKProduct
          (TKProduct
             (TKS2
                ((':)
                   @Natural
                   n
                   ((':)
                      @Natural
                      1
                      ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
          (TKProduct
             (TKS2
                ((':)
                   @Natural
                   n
                   ((':)
                      @Natural
                      n
                      ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
       (TKProduct
          (TKProduct
             (TKS2
                ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                (TKScalar r))
             (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
          (TKProduct
             (TKS2
                ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                (TKScalar r))
             (TKS2
                ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
 -> (Int, [MnistDataS r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':)
                        @Natural
                        n
                        ((':)
                           @Natural
                           1
                           ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':)
                        @Natural
                        n
                        ((':)
                           @Natural
                           n
                           ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                     (TKScalar r))
                  (TKS2
                     ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
       StateAdam
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':)
                        @Natural
                        n
                        ((':)
                           @Natural
                           1
                           ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':)
                        @Natural
                        n
                        ((':)
                           @Natural
                           n
                           ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                     (TKScalar r))
                  (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                     (TKScalar r))
                  (TKS2
                     ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        1
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        n
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                  (TKScalar r))
               (TKS2
                  ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
    StateAdam
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        1
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':)
                     @Natural
                     n
                     ((':)
                        @Natural
                        n
                        ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                  (TKScalar r))
               (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
            (TKProduct
               (TKS2
                  ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                  (TKScalar r))
               (TKS2
                  ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> [(Int, [MnistDataS r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> (Int, [MnistDataS r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          1
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Natural
                       n
                       ((':)
                          @Natural
                          n
                          ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                    (TKScalar r))
                 (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                    (TKScalar r))
                 (TKS2
                    ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
(Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
 StateAdam
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> (Int, [MnistDataS r])
-> IO
     (Concrete
        (X (ADCnnMnistParametersShaped
              Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
      StateAdam
        (X (ADCnnMnistParametersShaped
              Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
runBatch (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     1
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Natural
                  n
                  ((':)
                     @Natural
                     n
                     ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
               (TKScalar r))
            (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
               (TKScalar r))
            (TKS2
               ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
(Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
 StateAdam
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
paramsStateAdam [(Int, [MnistDataS r])]
chunks
             runEpoch (succ n) res
           ftk :: FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams kh kw c_out n_hidden r)) Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit
       res <- Int
-> (Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
    StateAdam
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> IO
     (Concrete
        (X (ADCnnMnistParametersShaped
              Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
runEpoch Int
1 (Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit, FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> StateAdam
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       1
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Natural
                    n
                    ((':)
                       @Natural
                       n
                       ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
                 (TKScalar r))
              (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
                 (TKScalar r))
              (TKS2
                 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
forall (y :: TK). FullShapeTK y -> StateAdam y
initialStateAdam FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
ftk)
       let testErrorFinal =
             r
1 r -> r -> r
forall a. Num a => a -> a -> a
- MnistDataBatchS n r
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest MnistDataBatchS n r
testDataS Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    1
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Natural
                 n
                 ((':)
                    @Natural
                    n
                    ((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
              (TKScalar r))
           (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
              (TKScalar r))
           (TKS2
              ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
  (X (ADCnnMnistParametersShaped
        Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
res
       assertEqualUpToEpsilon 1e-1 expected testErrorFinal

-- | Test group for the "CNNSO" pipeline, i.e. 'mnistTestCaseCNNSO'.
--
-- Each case passes a test name, two 'Int's, the kernel height and width
-- as type-level 'SNat's, four more 'Int's, and a final scalar.
-- That scalar is presumably the expected test error; it is 1 in every
-- case here. The scalar type ('Double' or 'Float') fixes the element type.
--
-- NOTE(review): going by the test names, the first two 'Int's look like
-- the number of epochs and the number of batches. Confirm against the
-- definition of 'mnistTestCaseCNNSO'.
tensorADValMnistTestsCNNSO :: TestTree
tensorADValMnistTestsCNNSO = testGroup "CNNS Once MNIST tests"
  [ mnistTestCaseCNNSO "CNNSO 1 epoch, 1 batch"
                       1 1 (SNat @4) (SNat @4) 8 16 1 1
                       (1 :: Double)
  , mnistTestCaseCNNSO "CNNSO artificial 1 2 3 4 5"
                       1 1 (SNat @2) (SNat @3) 4 5 1 10
                       (1 :: Float)
  , mnistTestCaseCNNSO "CNNSO artificial 5 4 3 2 1"
                       5 4 (SNat @3) (SNat @2) 1 1 1 1
                       (1 :: Double)
    -- Passing 0 as the second argument; presumably this exercises the
    -- degenerate "no batches" path.
  , mnistTestCaseCNNSO "CNNSO 1 epoch, 0 batch"
                       1 0 (SNat @4) (SNat @4) 16 64 16 50
                       (1.0 :: Float)
  ]