-- | Tests of "MnistRnnShaped2" recurrent neural networks using a few different
-- optimization pipelines.
--
-- Not LSTM.
-- Doesn't train without Adam, regardless of whether mini-batches are used.
-- It does train with Adam, but only after very carefully tweaking
-- initialization. Training is extremely sensitive to the initial parameters,
-- more than to anything else. Probably the gradient vanishes when parameters
-- are initialized from a probability distribution that doesn't have the right
-- variance. See
-- https://stats.stackexchange.com/questions/301285/what-is-vanishing-gradient.
-- Regularization/normalization might help as well.
module TestMnistRNNS
  ( testTrees
  ) where

import Prelude

import Control.Monad (foldM, unless)
import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality ((:~:) (Refl))
import GHC.TypeLits (KnownNat, sameNat)
import System.IO (hPutStrLn, stderr)
import System.Random
import Test.Tasty
import Test.Tasty.HUnit hiding (assert)
import Text.Printf

import Data.Array.Nested.Shaped.Shape

import HordeAd
import HordeAd.Core.Adaptor
import HordeAd.Core.AstEnv
import HordeAd.Core.AstFreshId
import HordeAd.Core.AstInterpret

import EqEpsilon

import MnistData
import MnistRnnShaped2 (ADRnnMnistParametersShaped)
import MnistRnnShaped2 qualified

-- TODO: optimize enough that it can run for one full epoch in reasonable time
-- and then verify it trains down to ~20% validation error in a short enough
-- time to include such a training run in tests.

-- | The tensor kind ('X') of the shaped RNN parameter record stored in
-- 'Concrete' tensors, with the record's height argument fixed to
-- 'SizeMnistHeight'. Values of type @Concrete (XParams w e)@ are turned back
-- into the parameter record with 'fromTarget'.
type XParams w e =
  X (ADRnnMnistParametersShaped Concrete SizeMnistHeight w e)

-- | The test groups run by this module. Each group is defined elsewhere in
-- this file; they differ in the optimization pipeline they exercise.
testTrees :: [TestTree]
testTrees = [ tensorADValMnistTestsRNNSA
            , tensorADValMnistTestsRNNSI
            , tensorADValMnistTestsRNNSO
            ]

-- POPL differentiation, straight via the ADVal instance of RankedTensor,
-- which side-steps vectorization.
mnistTestCaseRNNSA
  :: forall width batch_size r.
     (Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r)
  => String
  -> Int -> Int -> SNat width -> SNat batch_size -> Int -> r
  -> TestTree
mnistTestCaseRNNSA :: forall (width :: Nat) (batch_size :: Nat) r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> SNat width
-> SNat batch_size
-> Int
-> r
-> TestTree
mnistTestCaseRNNSA String
prefix Int
epochs Int
maxBatches width :: SNat width
width@SNat width
SNat batch_size :: SNat batch_size
batch_size@SNat batch_size
SNat
                   Int
totalBatchSize r
expected =
  let targetInit :: Concrete (XParams width r)
targetInit =
        (Concrete (XParams width r), StdGen) -> Concrete (XParams width r)
forall a b. (a, b) -> a
fst ((Concrete (XParams width r), StdGen)
 -> Concrete (XParams width r))
-> (Concrete (XParams width r), StdGen)
-> Concrete (XParams width r)
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue @(Concrete (XParams width r)) Double
0.23 (Int -> StdGen
mkStdGen Int
44)
      miniBatchSize :: Int
miniBatchSize = SNat batch_size -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat batch_size
batch_size
      name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
             String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
                        , Int -> String
forall a. Show a => a -> String
show (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width), Int -> String
forall a. Show a => a -> String
show Int
miniBatchSize
                        , Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams width r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
                          (SingletonTK (XParams width r) -> Int)
-> SingletonTK (XParams width r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams width r)
                        , Int -> String
forall a. Show a => a -> String
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                    (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
              (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
targetInit) ]
      ftest :: forall batch_size2. KnownNat batch_size2
            => MnistDataBatchS batch_size2 r -> Concrete (XParams width r)
            -> r
      ftest :: forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest MnistDataBatchS batch_size2 r
_ Concrete (XParams width r)
_ | Just (:~:) @Nat 0 batch_size2
Refl <- Proxy @Nat 0
-> Proxy @Nat batch_size2 -> Maybe ((:~:) @Nat 0 batch_size2)
forall (a :: Nat) (b :: Nat) (proxy1 :: Nat -> Type)
       (proxy2 :: Nat -> Type).
(KnownNat a, KnownNat b) =>
proxy1 a -> proxy2 b -> Maybe ((:~:) @Nat a b)
sameNat (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @0) (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @batch_size2) = r
0
      ftest MnistDataBatchS batch_size2 r
mnistData Concrete (XParams width r)
testParams =
        SNat width
-> SNat batch_size2
-> MnistDataBatchS batch_size2 r
-> ADRnnMnistParametersShaped Concrete SizeMnistHeight width r
-> r
forall (target :: TK -> Type) (h :: Nat) (w :: Nat)
       (out_width :: Nat) (batch_size :: Nat) r.
((h :: Nat) ~ (SizeMnistHeight :: Nat),
 (w :: Nat) ~ (SizeMnistHeight :: Nat),
 (target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
 Differentiable r, GoodScalar r) =>
SNat out_width
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADRnnMnistParametersShaped target h out_width r
-> r
MnistRnnShaped2.rnnMnistTestS
          SNat width
width (forall (n :: Nat). KnownNat n => SNat n
SNat @batch_size2) MnistDataBatchS batch_size2 r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete (XParams width r)
testParams)
  in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
    Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
      String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
             String
prefix Int
epochs Int
maxBatches
    trainData <- (MnistData r -> MnistDataS r) -> [MnistData r] -> [MnistDataS r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataS r
forall r. PrimElt r => MnistData r -> MnistDataS r
mkMnistDataS
                 ([MnistData r] -> [MnistDataS r])
-> IO [MnistData r] -> IO [MnistDataS r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
    testData <- map mkMnistDataS . take (totalBatchSize * maxBatches)
                <$> loadMnistData testGlyphsPath testLabelsPath
    withSNat ((totalBatchSize * maxBatches) `min` 10000)
     $ \(SNat @lenTestData) -> do
       let testDataS :: MnistDataBatchS n r
testDataS = forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS @lenTestData [MnistDataS r]
testData
           f :: MnistDataBatchS batch_size r
             -> ADVal Concrete (XParams width r)
             -> ADVal Concrete (TKScalar r)
           f :: MnistDataBatchS batch_size r
-> ADVal Concrete (XParams width r) -> ADVal Concrete (TKScalar r)
f (Shaped
  ((':)
     @Nat
     batch_size
     ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
  r
glyphS, Shaped
  ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
labelS) ADVal Concrete (XParams width r)
adinputs =
             SNat width
-> SNat batch_size
-> (PrimalOf
      (ADVal Concrete)
      (TKS
         ((':)
            @Nat
            batch_size
            ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
         r),
    PrimalOf
      (ADVal Concrete)
      (TKS
         ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
-> ADRnnMnistParametersShaped
     (ADVal Concrete) SizeMnistHeight width r
-> ADVal Concrete (TKScalar r)
forall (target :: TK -> Type) (h :: Nat) (w :: Nat)
       (out_width :: Nat) (batch_size :: Nat) r.
((h :: Nat) ~ (SizeMnistHeight :: Nat),
 (w :: Nat) ~ (SizeMnistHeight :: Nat), Differentiable r,
 ADReady target, ADReady (PrimalOf target), GoodScalar r) =>
SNat out_width
-> SNat batch_size
-> (PrimalOf
      target
      (TKS
         ((':) @Nat batch_size ((':) @Nat h ((':) @Nat w ('[] @Nat)))) r),
    PrimalOf
      target
      (TKS
         ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
-> ADRnnMnistParametersShaped target h out_width r
-> target (TKScalar r)
MnistRnnShaped2.rnnMnistLossFusedS
               SNat width
width SNat batch_size
batch_size (Shaped
  ((':)
     @Nat
     batch_size
     ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
  r
-> Concrete
     (TKS
        ((':)
           @Nat
           batch_size
           ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
        r)
forall r (target :: TK -> Type) (sh :: [Nat]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
  ((':)
     @Nat
     batch_size
     ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
  r
glyphS, Shaped
  ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
-> Concrete
     (TKS
        ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
forall r (target :: TK -> Type) (sh :: [Nat]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
  ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
labelS)
               (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @(ADVal Concrete) ADVal
  Concrete
  (X (ADRnnMnistParametersShaped
        (ADVal Concrete) SizeMnistHeight width r))
ADVal Concrete (XParams width r)
adinputs)
           runBatch :: ( Concrete (XParams width r)
                       , StateAdam (XParams width r) )
                    -> (Int, [MnistDataS r])
                    -> IO ( Concrete (XParams width r)
                          , StateAdam (XParams width r) )
           runBatch :: (Concrete (XParams width r), StateAdam (XParams width r))
-> (Int, [MnistDataS r])
-> IO (Concrete (XParams width r), StateAdam (XParams width r))
runBatch (!Concrete (XParams width r)
parameters, !StateAdam (XParams width r)
stateAdam) (Int
k, [MnistDataS r]
chunk) = do
             let chunkS :: [MnistDataBatchS batch_size r]
chunkS = ([MnistDataS r] -> MnistDataBatchS batch_size r)
-> [[MnistDataS r]] -> [MnistDataBatchS batch_size r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataS r] -> MnistDataBatchS batch_size r
forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS
                          ([[MnistDataS r]] -> [MnistDataBatchS batch_size r])
-> [[MnistDataS r]] -> [MnistDataBatchS batch_size r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataS r] -> Bool) -> [[MnistDataS r]] -> [[MnistDataS r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataS r]
ch -> [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSize)
                          ([[MnistDataS r]] -> [[MnistDataS r]])
-> [[MnistDataS r]] -> [[MnistDataS r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataS r]
chunk
                 res :: (Concrete (XParams width r), StateAdam (XParams width r))
res@(Concrete (XParams width r)
parameters2, StateAdam (XParams width r)
_) =
                   forall a (x :: TK) (z :: TK).
KnownSTK x =>
(a -> ADVal Concrete x -> ADVal Concrete z)
-> [a] -> Concrete x -> StateAdam x -> (Concrete x, StateAdam x)
sgdAdam @(MnistDataBatchS batch_size r)
                               @(XParams width r)
                               MnistDataBatchS batch_size r
-> ADVal Concrete (XParams width r) -> ADVal Concrete (TKScalar r)
f [MnistDataBatchS batch_size r]
chunkS Concrete (XParams width r)
parameters StateAdam (XParams width r)
stateAdam
                 trainScore :: r
trainScore = Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat ([MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk) ((forall (n :: Nat). KnownNat n => SNat n -> r) -> r)
-> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
forall a b. (a -> b) -> a -> b
$ \(SNat @len) ->
                   forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @len ([MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
chunk) Concrete (XParams width r)
parameters2
                 testScore :: r
testScore = forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @lenTestData MnistDataBatchS n r
testDataS Concrete (XParams width r)
parameters2
                 lenChunk :: Int
lenChunk = [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
                        String
prefix Int
k Int
lenChunk
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error:   %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
             (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
res
       let runEpoch :: Int
                    -> ( Concrete (XParams width r)
                       , StateAdam (XParams width r) )
                    -> IO (Concrete (XParams width r))
           runEpoch :: Int
-> (Concrete (XParams width r), StateAdam (XParams width r))
-> IO (Concrete (XParams width r))
runEpoch Int
n (Concrete (XParams width r)
params2, StateAdam (XParams width r)
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
params2
           runEpoch Int
n paramsStateAdam :: (Concrete (XParams width r), StateAdam (XParams width r))
paramsStateAdam@(!Concrete (XParams width r)
_, !StateAdam (XParams width r)
_) = do
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
             let trainDataShuffled :: [MnistDataS r]
trainDataShuffled = StdGen -> [MnistDataS r] -> [MnistDataS r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataS r]
trainData
                 chunks :: [(Int, [MnistDataS r])]
chunks = Int -> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                          ([(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])])
-> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..]
                          ([[MnistDataS r]] -> [(Int, [MnistDataS r])])
-> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataS r]
trainDataShuffled
             res <- ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct
                (TKS2
                   ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                   (TKScalar r))
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
             (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
          (TKProduct
             (TKProduct
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
             (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
       (TKProduct
          (TKS2
             ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
             (TKScalar r))
          (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
  StateAdam
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct
                (TKS2
                   ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                   (TKScalar r))
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
             (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
          (TKProduct
             (TKProduct
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
             (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
       (TKProduct
          (TKS2
             ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
             (TKScalar r))
          (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
 -> (Int, [MnistDataS r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct
                     (TKS2
                        ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                        (TKScalar r))
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKProduct
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKS2
                  ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
       StateAdam
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct
                     (TKS2
                        ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                        (TKScalar r))
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKProduct
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKS2
                  ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                     (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
            (TKProduct
               (TKProduct
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
         (TKProduct
            (TKS2
               ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
               (TKScalar r))
            (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
    StateAdam
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                     (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
            (TKProduct
               (TKProduct
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
         (TKProduct
            (TKS2
               ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
               (TKScalar r))
            (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> [(Int, [MnistDataS r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> (Int, [MnistDataS r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
-> (Int, [MnistDataS r])
-> IO (Concrete (XParams width r), StateAdam (XParams width r))
runBatch (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
paramsStateAdam [(Int, [MnistDataS r])]
chunks
             runEpoch (succ n) res
           ftk :: FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams width r))
                      Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
targetInit
       res <- Int
-> (Concrete (XParams width r), StateAdam (XParams width r))
-> IO (Concrete (XParams width r))
runEpoch Int
1 (Concrete (XParams width r)
targetInit, FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> StateAdam
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                    (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
              (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
forall (y :: TK). FullShapeTK y -> StateAdam y
initialStateAdam FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
ftk)
       let testErrorFinal = r
1 r -> r -> r
forall a. Num a => a -> a -> a
- forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @lenTestData MnistDataBatchS n r
testDataS Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
res
       testErrorFinal @?~ expected

-- Ask GHC to also compile a copy of 'mnistTestCaseRNNSA' specialised to the
-- scalar type r = Double (the type of the Double-valued test cases below);
-- width and batch_size stay polymorphic.
{-# SPECIALIZE mnistTestCaseRNNSA
  :: String
  -> Int -> Int -> SNat width -> SNat batch_size -> Int -> Double
  -> TestTree #-}

-- | Test group for the ADVal pipeline. Each case runs 'mnistTestCaseRNNSA'
-- with a fixed configuration and checks the final test error against a
-- recorded value.
--
-- Arguments per case, in order: test-name prefix, epochs, max batches per
-- epoch, RNN width, mini-batch size, total batch size, and expected final
-- test error. The expected value is compared via @testErrorFinal \@?~ expected@,
-- where @testErrorFinal = 1 - ftest ...@.
-- NOTE(review): argument meanings are inferred from the identical signature
-- of 'mnistTestCaseRNNSI'; confirm against 'mnistTestCaseRNNSA'.
tensorADValMnistTestsRNNSA :: TestTree
tensorADValMnistTestsRNNSA = testGroup "RNNS ADVal MNIST tests"
  [ mnistTestCaseRNNSA "RNNSA 1 epoch, 1 batch" 1 1 (SNat @128) (SNat @150) 5000
                       (0.6026 :: Double)
    -- Small "artificial" widths and batch sizes keep these cases fast.
  , mnistTestCaseRNNSA "RNNSA artificial 1 2 3 4 5" 2 3 (SNat @4) (SNat @5) 50
                       (0.8933333 :: Float)
  , mnistTestCaseRNNSA "RNNSA artificial 5 4 3 2 1" 5 4 (SNat @3) (SNat @2) 49
                       (0.8622448979591837 :: Double)
    -- Zero batches: presumably no test data is evaluated, so the reported
    -- error is 1 - 0 = 1.0 (TODO confirm against mnistTestCaseRNNSA).
  , mnistTestCaseRNNSA "RNNSA 1 epoch, 0 batch" 1 0 (SNat @128) (SNat @150) 50
                       (1.0 :: Float)
  ]

-- POPL differentiation, with Ast term defined and vectorized only once,
-- but differentiated anew in each gradient descent iteration.
mnistTestCaseRNNSI
  :: forall width batch_size r.
     (Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r)
  => String
  -> Int -> Int -> SNat width -> SNat batch_size -> Int -> r
  -> TestTree
mnistTestCaseRNNSI :: forall (width :: Nat) (batch_size :: Nat) r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> SNat width
-> SNat batch_size
-> Int
-> r
-> TestTree
mnistTestCaseRNNSI String
prefix Int
epochs Int
maxBatches width :: SNat width
width@SNat width
SNat batch_size :: SNat batch_size
batch_size@SNat batch_size
SNat
                   Int
totalBatchSize r
expected =
  let targetInit :: Concrete (XParams width r)
targetInit =
        (Concrete (XParams width r), StdGen) -> Concrete (XParams width r)
forall a b. (a, b) -> a
fst ((Concrete (XParams width r), StdGen)
 -> Concrete (XParams width r))
-> (Concrete (XParams width r), StdGen)
-> Concrete (XParams width r)
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue @(Concrete (XParams width r)) Double
0.23 (Int -> StdGen
mkStdGen Int
44)
      miniBatchSize :: Int
miniBatchSize = SNat batch_size -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat batch_size
batch_size
      name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
             String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
                        , Int -> String
forall a. Show a => a -> String
show (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width), Int -> String
forall a. Show a => a -> String
show Int
miniBatchSize
                        , Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams width r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
                          (SingletonTK (XParams width r) -> Int)
-> SingletonTK (XParams width r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams width r)
                        , Int -> String
forall a. Show a => a -> String
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                    (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
              (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
targetInit) ]
      ftest :: forall batch_size2. KnownNat batch_size2
            => MnistDataBatchS batch_size2 r -> Concrete (XParams width r)
            -> r
      ftest :: forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest MnistDataBatchS batch_size2 r
_ Concrete (XParams width r)
_ | Just (:~:) @Nat 0 batch_size2
Refl <- Proxy @Nat 0
-> Proxy @Nat batch_size2 -> Maybe ((:~:) @Nat 0 batch_size2)
forall (a :: Nat) (b :: Nat) (proxy1 :: Nat -> Type)
       (proxy2 :: Nat -> Type).
(KnownNat a, KnownNat b) =>
proxy1 a -> proxy2 b -> Maybe ((:~:) @Nat a b)
sameNat (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @0) (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @batch_size2) = r
0
      ftest MnistDataBatchS batch_size2 r
mnistData Concrete (XParams width r)
testParams =
        SNat width
-> SNat batch_size2
-> MnistDataBatchS batch_size2 r
-> ADRnnMnistParametersShaped Concrete SizeMnistHeight width r
-> r
forall (target :: TK -> Type) (h :: Nat) (w :: Nat)
       (out_width :: Nat) (batch_size :: Nat) r.
((h :: Nat) ~ (SizeMnistHeight :: Nat),
 (w :: Nat) ~ (SizeMnistHeight :: Nat),
 (target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
 Differentiable r, GoodScalar r) =>
SNat out_width
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADRnnMnistParametersShaped target h out_width r
-> r
MnistRnnShaped2.rnnMnistTestS
          SNat width
width (forall (n :: Nat). KnownNat n => SNat n
SNat @batch_size2) MnistDataBatchS batch_size2 r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete (XParams width r)
testParams)
  in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
    Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
      String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
             String
prefix Int
epochs Int
maxBatches
    trainData <- (MnistData r -> MnistDataS r) -> [MnistData r] -> [MnistDataS r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataS r
forall r. PrimElt r => MnistData r -> MnistDataS r
mkMnistDataS
                 ([MnistData r] -> [MnistDataS r])
-> IO [MnistData r] -> IO [MnistDataS r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
    testData <- map mkMnistDataS . take (totalBatchSize * maxBatches)
                <$> loadMnistData testGlyphsPath testLabelsPath
    withSNat ((totalBatchSize * maxBatches) `min` 10000)
     $ \(SNat @lenTestData) -> do
       let testDataS :: MnistDataBatchS n r
testDataS = forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS @lenTestData [MnistDataS r]
testData
           ftk :: FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams width r)) Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
targetInit
       (_, _, var, varAst) <- FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> IO
     (AstVarName
        PrimalSpan
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
      AstTensor
        AstMethodShare
        PrimalSpan
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
      AstVarName
        FullSpan
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
      AstTensor
        AstMethodLet
        FullSpan
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall (x :: TK).
FullShapeTK x
-> IO
     (AstVarName PrimalSpan x, AstTensor AstMethodShare PrimalSpan x,
      AstVarName FullSpan x, AstTensor AstMethodLet FullSpan x)
funToAstRevIO FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
ftk
       (varGlyph, astGlyph) <- funToAstIO (FTKS knownShS FTKScalar) id
       (varLabel, astLabel) <- funToAstIO (FTKS knownShS FTKScalar) id
       let ast :: AstTensor AstMethodLet FullSpan (TKScalar r)
           ast = AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (z :: TK) (s :: AstSpanType).
AstSpan s =>
AstTensor AstMethodLet s z -> AstTensor AstMethodLet s z
simplifyInline
                 (AstTensor AstMethodLet FullSpan (TKScalar r)
 -> AstTensor AstMethodLet FullSpan (TKScalar r))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall a b. (a -> b) -> a -> b
$ SNat width
-> SNat batch_size
-> (PrimalOf
      (AstTensor AstMethodLet FullSpan)
      (TKS2
         ((':)
            @Nat
            batch_size
            ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
         (TKScalar r)),
    PrimalOf
      (AstTensor AstMethodLet FullSpan)
      (TKS2
         ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
         (TKScalar r)))
-> ADRnnMnistParametersShaped
     (AstTensor AstMethodLet FullSpan) SizeMnistHeight width r
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (target :: TK -> Type) (h :: Nat) (w :: Nat)
       (out_width :: Nat) (batch_size :: Nat) r.
((h :: Nat) ~ (SizeMnistHeight :: Nat),
 (w :: Nat) ~ (SizeMnistHeight :: Nat), Differentiable r,
 ADReady target, ADReady (PrimalOf target), GoodScalar r) =>
SNat out_width
-> SNat batch_size
-> (PrimalOf
      target
      (TKS
         ((':) @Nat batch_size ((':) @Nat h ((':) @Nat w ('[] @Nat)))) r),
    PrimalOf
      target
      (TKS
         ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
-> ADRnnMnistParametersShaped target h out_width r
-> target (TKScalar r)
MnistRnnShaped2.rnnMnistLossFusedS
                     SNat width
width SNat batch_size
batch_size (AstTensor
  AstMethodLet
  PrimalSpan
  (TKS2
     ((':)
        @Nat
        batch_size
        ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
     (TKScalar r))
PrimalOf
  (AstTensor AstMethodLet FullSpan)
  (TKS2
     ((':)
        @Nat
        batch_size
        ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
     (TKScalar r))
astGlyph, AstTensor
  AstMethodLet
  PrimalSpan
  (TKS2
     ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
     (TKScalar r))
PrimalOf
  (AstTensor AstMethodLet FullSpan)
  (TKS2
     ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
     (TKScalar r))
astLabel)
                     (AstTensor
  AstMethodLet
  FullSpan
  (X (ADRnnMnistParametersShaped
        (AstTensor AstMethodLet FullSpan) SizeMnistHeight width r))
-> ADRnnMnistParametersShaped
     (AstTensor AstMethodLet FullSpan) SizeMnistHeight width r
forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget AstTensor
  AstMethodLet
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
AstTensor
  AstMethodLet
  FullSpan
  (X (ADRnnMnistParametersShaped
        (AstTensor AstMethodLet FullSpan) SizeMnistHeight width r))
varAst)
           f :: MnistDataBatchS batch_size r
             -> ADVal Concrete (XParams width r)
             -> ADVal Concrete (TKScalar r)
           f (Shaped
  ((':)
     @Nat
     batch_size
     ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
  r
glyph, Shaped
  ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
label) ADVal Concrete (XParams width r)
varInputs =
             let env :: AstEnv (ADVal Concrete)
env = AstVarName
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                    (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
              (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
var ADVal
  Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
ADVal Concrete (XParams width r)
varInputs AstEnv (ADVal Concrete)
forall (target :: TK -> Type). AstEnv target
emptyEnv
                 envMnist :: AstEnv (ADVal Concrete)
envMnist = AstVarName
  PrimalSpan
  (TKS2
     ((':)
        @Nat
        batch_size
        ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
     (TKScalar r))
-> ADVal
     Concrete
     (TKS2
        ((':)
           @Nat
           batch_size
           ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
        (TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
  PrimalSpan
  (TKS2
     ((':)
        @Nat
        batch_size
        ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
     (TKScalar r))
varGlyph (Shaped
  ((':)
     @Nat
     batch_size
     ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
  r
-> ADVal
     Concrete
     (TKS2
        ((':)
           @Nat
           batch_size
           ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
        (TKScalar r))
forall r (target :: TK -> Type) (sh :: [Nat]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
  ((':)
     @Nat
     batch_size
     ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
  r
glyph)
                            (AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete))
-> AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete)
forall a b. (a -> b) -> a -> b
$ AstVarName
  PrimalSpan
  (TKS2
     ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
     (TKScalar r))
-> ADVal
     Concrete
     (TKS2
        ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
        (TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
  PrimalSpan
  (TKS2
     ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
     (TKScalar r))
varLabel (Shaped
  ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
-> ADVal
     Concrete
     (TKS2
        ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
        (TKScalar r))
forall r (target :: TK -> Type) (sh :: [Nat]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
  ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
label) AstEnv (ADVal Concrete)
env
             in AstEnv (ADVal Concrete)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> ADVal Concrete (TKScalar r)
forall (target :: TK -> Type) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull AstEnv (ADVal Concrete)
envMnist AstTensor AstMethodLet FullSpan (TKScalar r)
ast
           runBatch :: ( Concrete (XParams width r)
                       , StateAdam (XParams width r) )
                    -> (Int, [MnistDataS r])
                    -> IO ( Concrete (XParams width r)
                          , StateAdam (XParams width r) )
           runBatch (!Concrete (XParams width r)
parameters, !StateAdam (XParams width r)
stateAdam) (Int
k, [MnistDataS r]
chunk) = do
             let chunkS :: [MnistDataBatchS batch_size r]
chunkS = ([MnistDataS r] -> MnistDataBatchS batch_size r)
-> [[MnistDataS r]] -> [MnistDataBatchS batch_size r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataS r] -> MnistDataBatchS batch_size r
forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS
                          ([[MnistDataS r]] -> [MnistDataBatchS batch_size r])
-> [[MnistDataS r]] -> [MnistDataBatchS batch_size r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataS r] -> Bool) -> [[MnistDataS r]] -> [[MnistDataS r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataS r]
ch -> [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSize)
                          ([[MnistDataS r]] -> [[MnistDataS r]])
-> [[MnistDataS r]] -> [[MnistDataS r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataS r]
chunk
                 res :: (Concrete (XParams width r), StateAdam (XParams width r))
res@(Concrete (XParams width r)
parameters2, StateAdam (XParams width r)
_) =
                   forall a (x :: TK) (z :: TK).
KnownSTK x =>
(a -> ADVal Concrete x -> ADVal Concrete z)
-> [a] -> Concrete x -> StateAdam x -> (Concrete x, StateAdam x)
sgdAdam @(MnistDataBatchS batch_size r)
                               @(XParams width r)
                               MnistDataBatchS batch_size r
-> ADVal Concrete (XParams width r) -> ADVal Concrete (TKScalar r)
f [MnistDataBatchS batch_size r]
chunkS Concrete (XParams width r)
parameters StateAdam (XParams width r)
stateAdam
                 trainScore :: r
trainScore = Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat ([MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk) ((forall (n :: Nat). KnownNat n => SNat n -> r) -> r)
-> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
forall a b. (a -> b) -> a -> b
$ \(SNat @len) ->
                   forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @len ([MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
chunk) Concrete (XParams width r)
parameters2
                 testScore :: r
testScore = forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @lenTestData MnistDataBatchS n r
testDataS Concrete (XParams width r)
parameters2
                 lenChunk :: Int
lenChunk = [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
                        String
prefix Int
k Int
lenChunk
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error:   %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
             (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
res
       let runEpoch :: Int
                    -> ( Concrete (XParams width r)
                       , StateAdam (XParams width r) )
                    -> IO (Concrete (XParams width r))
           runEpoch Int
n (Concrete (XParams width r)
params2, StateAdam (XParams width r)
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
params2
           runEpoch Int
n paramsStateAdam :: (Concrete (XParams width r), StateAdam (XParams width r))
paramsStateAdam@(!Concrete (XParams width r)
_, !StateAdam (XParams width r)
_) = do
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
             let trainDataShuffled :: [MnistDataS r]
trainDataShuffled = StdGen -> [MnistDataS r] -> [MnistDataS r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataS r]
trainData
                 chunks :: [(Int, [MnistDataS r])]
chunks = Int -> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                          ([(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])])
-> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..]
                          ([[MnistDataS r]] -> [(Int, [MnistDataS r])])
-> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataS r]
trainDataShuffled
             res <- ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct
                (TKS2
                   ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                   (TKScalar r))
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
             (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
          (TKProduct
             (TKProduct
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
             (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
       (TKProduct
          (TKS2
             ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
             (TKScalar r))
          (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
  StateAdam
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct
                (TKS2
                   ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                   (TKScalar r))
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
             (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
          (TKProduct
             (TKProduct
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
             (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
       (TKProduct
          (TKS2
             ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
             (TKScalar r))
          (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
 -> (Int, [MnistDataS r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct
                     (TKS2
                        ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                        (TKScalar r))
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKProduct
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKS2
                  ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
       StateAdam
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct
                     (TKS2
                        ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                        (TKScalar r))
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKProduct
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKS2
                  ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                     (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
            (TKProduct
               (TKProduct
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
         (TKProduct
            (TKS2
               ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
               (TKScalar r))
            (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
    StateAdam
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                     (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
            (TKProduct
               (TKProduct
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
         (TKProduct
            (TKS2
               ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
               (TKScalar r))
            (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> [(Int, [MnistDataS r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> (Int, [MnistDataS r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
-> (Int, [MnistDataS r])
-> IO (Concrete (XParams width r), StateAdam (XParams width r))
runBatch (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
paramsStateAdam [(Int, [MnistDataS r])]
chunks
             runEpoch (succ n) res
       res <- runEpoch 1 (targetInit, initialStateAdam ftk)
       let testErrorFinal = r
1 r -> r -> r
forall a. Num a => a -> a -> a
- forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @lenTestData MnistDataBatchS n r
testDataS Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
res
       testErrorFinal @?~ expected

-- Ask GHC to compile an extra copy of 'mnistTestCaseRNNSI' specialized to
-- scalar type r = Double, so that instantiation avoids passing class
-- dictionaries at runtime. Other scalar types (e.g. Float) still go through
-- the generic, dictionary-passing version.
{-# SPECIALIZE mnistTestCaseRNNSI
  :: String
  -> Int -> Int -> SNat width -> SNat batch_size -> Int -> Double
  -> TestTree #-}

-- | Test group running 'mnistTestCaseRNNSI' on a few configurations.
--
-- Arguments of each case, in order (assumed to follow the parameter names of
-- 'mnistTestCaseRNNSO', which has the same type: prefix, epochs, maxBatches,
-- width, batch_size, totalBatchSize, expected -- TODO confirm):
--
--   * test-name prefix,
--   * number of epochs,
--   * maximum number of mini-batches per epoch,
--   * hidden-layer width (type-level, via 'SNat'),
--   * mini-batch size (type-level, via 'SNat'),
--   * total batch size,
--   * expected final test error; its type annotation also fixes the scalar
--     type @r@ (Double or Float) the case is run at.
tensorADValMnistTestsRNNSI :: TestTree
tensorADValMnistTestsRNNSI = testGroup "RNNS Intermediate MNIST tests"
  [ mnistTestCaseRNNSI "RNNSI 1 epoch, 1 batch" 1 1 (SNat @128) (SNat @150) 5000
                       (0.6026 :: Double)
  , mnistTestCaseRNNSI "RNNSI artificial 1 2 3 4 5" 2 3 (SNat @4) (SNat @5) 50
                       (0.8933333 :: Float)
  , mnistTestCaseRNNSI "RNNSI artificial 5 4 3 2 1" 5 4 (SNat @3) (SNat @2) 49
                       (0.8622448979591837 :: Double)
    -- With zero batches, presumably no test data is taken, so the reported
    -- test error is 1 minus an accuracy of 0.
  , mnistTestCaseRNNSI "RNNSI 1 epoch, 0 batch" 1 0 (SNat @128) (SNat @150) 50
                       (1.0 :: Float)
  ]

-- JAX differentiation, Ast term built and differentiated only once
-- and the result interpreted with different inputs in each gradient
-- descent iteration.
mnistTestCaseRNNSO
  :: forall width batch_size r.
     ( Differentiable r, GoodScalar r
     , PrintfArg r, AssertEqualUpToEpsilon r, ADTensorScalar r ~ r )
  => String
  -> Int -> Int -> SNat width -> SNat batch_size -> Int -> r
  -> TestTree
mnistTestCaseRNNSO :: forall (width :: Nat) (batch_size :: Nat) r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r,
 (ADTensorScalar r :: Type) ~ (r :: Type)) =>
String
-> Int
-> Int
-> SNat width
-> SNat batch_size
-> Int
-> r
-> TestTree
mnistTestCaseRNNSO String
prefix Int
epochs Int
maxBatches width :: SNat width
width@SNat width
SNat batch_size :: SNat batch_size
batch_size@SNat batch_size
SNat
                   Int
totalBatchSize r
expected =
  let targetInit :: Concrete (XParams width r)
targetInit =
        (Concrete (XParams width r), StdGen) -> Concrete (XParams width r)
forall a b. (a, b) -> a
fst ((Concrete (XParams width r), StdGen)
 -> Concrete (XParams width r))
-> (Concrete (XParams width r), StdGen)
-> Concrete (XParams width r)
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue @(Concrete (XParams width r)) Double
0.23 (Int -> StdGen
mkStdGen Int
44)
      miniBatchSize :: Int
miniBatchSize = SNat batch_size -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat batch_size
batch_size
      name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
             String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
                        , Int -> String
forall a. Show a => a -> String
show (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width), Int -> String
forall a. Show a => a -> String
show Int
miniBatchSize
                        , Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams width r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
                          (SingletonTK (XParams width r) -> Int)
-> SingletonTK (XParams width r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams width r)
                        , Int -> String
forall a. Show a => a -> String
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                    (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
              (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
targetInit) ]
      ftest :: forall batch_size2. KnownNat batch_size2
            => MnistDataBatchS batch_size2 r -> Concrete (XParams width r)
            -> r
      ftest :: forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest MnistDataBatchS batch_size2 r
_ Concrete (XParams width r)
_ | Just (:~:) @Nat 0 batch_size2
Refl <- Proxy @Nat 0
-> Proxy @Nat batch_size2 -> Maybe ((:~:) @Nat 0 batch_size2)
forall (a :: Nat) (b :: Nat) (proxy1 :: Nat -> Type)
       (proxy2 :: Nat -> Type).
(KnownNat a, KnownNat b) =>
proxy1 a -> proxy2 b -> Maybe ((:~:) @Nat a b)
sameNat (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @0) (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @batch_size2) = r
0
      ftest MnistDataBatchS batch_size2 r
mnistData Concrete (XParams width r)
testParams =
        SNat width
-> SNat batch_size2
-> MnistDataBatchS batch_size2 r
-> ADRnnMnistParametersShaped Concrete SizeMnistHeight width r
-> r
forall (target :: TK -> Type) (h :: Nat) (w :: Nat)
       (out_width :: Nat) (batch_size :: Nat) r.
((h :: Nat) ~ (SizeMnistHeight :: Nat),
 (w :: Nat) ~ (SizeMnistHeight :: Nat),
 (target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
 Differentiable r, GoodScalar r) =>
SNat out_width
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADRnnMnistParametersShaped target h out_width r
-> r
MnistRnnShaped2.rnnMnistTestS
          SNat width
width (forall (n :: Nat). KnownNat n => SNat n
SNat @batch_size2) MnistDataBatchS batch_size2 r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete (XParams width r)
testParams)
  in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
    Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
      String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
             String
prefix Int
epochs Int
maxBatches
    trainData <- (MnistData r -> MnistDataS r) -> [MnistData r] -> [MnistDataS r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataS r
forall r. PrimElt r => MnistData r -> MnistDataS r
mkMnistDataS
                 ([MnistData r] -> [MnistDataS r])
-> IO [MnistData r] -> IO [MnistDataS r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
    testData <- map mkMnistDataS . take (totalBatchSize * maxBatches)
                <$> loadMnistData testGlyphsPath testLabelsPath
    withSNat ((totalBatchSize * maxBatches) `min` 10000)
     $ \(SNat @lenTestData) -> do
       let testDataS :: MnistDataBatchS n r
testDataS = forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS @lenTestData [MnistDataS r]
testData
           ftk :: FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams width r)) Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
targetInit
           ftkData :: FullShapeTK
  (TKProduct
     (TKS
        ((':)
           @Nat
           batch_size
           ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
        r)
     (TKS
        ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
ftkData = FullShapeTK
  (TKS
     ((':)
        @Nat
        batch_size
        ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
     r)
-> FullShapeTK
     (TKS
        ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
-> FullShapeTK
     (TKProduct
        (TKS
           ((':)
              @Nat
              batch_size
              ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
           r)
        (TKS
           ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct (ShS
  ((':)
     @Nat
     batch_size
     ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
-> FullShapeTK (TKScalar r)
-> FullShapeTK
     (TKS
        ((':)
           @Nat
           batch_size
           ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
        r)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS (SNat batch_size
batch_size
                                       SNat batch_size
-> ShS
     ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat)))
-> ShS
     ((':)
        @Nat
        batch_size
        ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
forall {sh1 :: [Nat]} (n :: Nat) (sh :: [Nat]).
(KnownNat n, ((':) @Nat n sh :: [Nat]) ~ (sh1 :: [Nat])) =>
SNat n -> ShS sh -> ShS sh1
:$$ SNat SizeMnistHeight
sizeMnistHeight
                                       SNat SizeMnistHeight
-> ShS ((':) @Nat SizeMnistHeight ('[] @Nat))
-> ShS
     ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat)))
forall {sh1 :: [Nat]} (n :: Nat) (sh :: [Nat]).
(KnownNat n, ((':) @Nat n sh :: [Nat]) ~ (sh1 :: [Nat])) =>
SNat n -> ShS sh -> ShS sh1
:$$ SNat SizeMnistHeight
sizeMnistWidth
                                       SNat SizeMnistHeight
-> ShS ('[] @Nat) -> ShS ((':) @Nat SizeMnistHeight ('[] @Nat))
forall {sh1 :: [Nat]} (n :: Nat) (sh :: [Nat]).
(KnownNat n, ((':) @Nat n sh :: [Nat]) ~ (sh1 :: [Nat])) =>
SNat n -> ShS sh -> ShS sh1
:$$ ShS ('[] @Nat)
forall (sh :: [Nat]).
((sh :: [Nat]) ~ ('[] @Nat :: [Nat])) =>
ShS sh
ZSS) FullShapeTK (TKScalar r)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar)
                                (ShS ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
-> FullShapeTK (TKScalar r)
-> FullShapeTK
     (TKS
        ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS (SNat batch_size
batch_size
                                       SNat batch_size
-> ShS ((':) @Nat SizeMnistLabel ('[] @Nat))
-> ShS ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
forall {sh1 :: [Nat]} (n :: Nat) (sh :: [Nat]).
(KnownNat n, ((':) @Nat n sh :: [Nat]) ~ (sh1 :: [Nat])) =>
SNat n -> ShS sh -> ShS sh1
:$$ SNat SizeMnistLabel
sizeMnistLabel
                                       SNat SizeMnistLabel
-> ShS ('[] @Nat) -> ShS ((':) @Nat SizeMnistLabel ('[] @Nat))
forall {sh1 :: [Nat]} (n :: Nat) (sh :: [Nat]).
(KnownNat n, ((':) @Nat n sh :: [Nat]) ~ (sh1 :: [Nat])) =>
SNat n -> ShS sh -> ShS sh1
:$$ ShS ('[] @Nat)
forall (sh :: [Nat]).
((sh :: [Nat]) ~ ('[] @Nat :: [Nat])) =>
ShS sh
ZSS) FullShapeTK (TKScalar r)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar)
           f :: ( ADRnnMnistParametersShaped (AstTensor AstMethodLet FullSpan)
                    SizeMnistHeight width r
                , ( AstTensor AstMethodLet FullSpan
                      (TKS '[batch_size, SizeMnistHeight, SizeMnistWidth] r)
                  , AstTensor AstMethodLet FullSpan
                      (TKS '[batch_size, SizeMnistLabel] r) ) )
             -> AstTensor AstMethodLet FullSpan (TKScalar r)
           f :: (ADRnnMnistParametersShaped
   (AstTensor AstMethodLet FullSpan) SizeMnistHeight width r,
 (AstTensor
    AstMethodLet
    FullSpan
    (TKS
       ((':)
          @Nat
          batch_size
          ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
       r),
  AstTensor
    AstMethodLet
    FullSpan
    (TKS
       ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f = \ (ADRnnMnistParametersShaped
  (AstTensor AstMethodLet FullSpan) SizeMnistHeight width r
pars, (AstTensor
  AstMethodLet
  FullSpan
  (TKS
     ((':)
        @Nat
        batch_size
        ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
     r)
glyphS, AstTensor
  AstMethodLet
  FullSpan
  (TKS
     ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
labelS)) ->
             SNat width
-> SNat batch_size
-> (PrimalOf
      (AstTensor AstMethodLet FullSpan)
      (TKS
         ((':)
            @Nat
            batch_size
            ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
         r),
    PrimalOf
      (AstTensor AstMethodLet FullSpan)
      (TKS
         ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
-> ADRnnMnistParametersShaped
     (AstTensor AstMethodLet FullSpan) SizeMnistHeight width r
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (target :: TK -> Type) (h :: Nat) (w :: Nat)
       (out_width :: Nat) (batch_size :: Nat) r.
((h :: Nat) ~ (SizeMnistHeight :: Nat),
 (w :: Nat) ~ (SizeMnistHeight :: Nat), Differentiable r,
 ADReady target, ADReady (PrimalOf target), GoodScalar r) =>
SNat out_width
-> SNat batch_size
-> (PrimalOf
      target
      (TKS
         ((':) @Nat batch_size ((':) @Nat h ((':) @Nat w ('[] @Nat)))) r),
    PrimalOf
      target
      (TKS
         ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
-> ADRnnMnistParametersShaped target h out_width r
-> target (TKScalar r)
MnistRnnShaped2.rnnMnistLossFusedS
               SNat width
width SNat batch_size
batch_size (AstTensor
  AstMethodLet
  FullSpan
  (TKS
     ((':)
        @Nat
        batch_size
        ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
     r)
-> PrimalOf
     (AstTensor AstMethodLet FullSpan)
     (TKS
        ((':)
           @Nat
           batch_size
           ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
        r)
forall (target :: TK -> Type) (sh :: [Nat]) (x :: TK).
BaseTensor target =>
target (TKS2 sh x) -> PrimalOf target (TKS2 sh x)
sprimalPart AstTensor
  AstMethodLet
  FullSpan
  (TKS
     ((':)
        @Nat
        batch_size
        ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
     r)
glyphS, AstTensor
  AstMethodLet
  FullSpan
  (TKS
     ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
-> PrimalOf
     (AstTensor AstMethodLet FullSpan)
     (TKS
        ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
forall (target :: TK -> Type) (sh :: [Nat]) (x :: TK).
BaseTensor target =>
target (TKS2 sh x) -> PrimalOf target (TKS2 sh x)
sprimalPart AstTensor
  AstMethodLet
  FullSpan
  (TKS
     ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
labelS) ADRnnMnistParametersShaped
  (AstTensor AstMethodLet FullSpan) SizeMnistHeight width r
pars
           artRaw :: AstArtifactRev
  (X (ADRnnMnistParametersShaped
        (AstTensor AstMethodLet FullSpan) SizeMnistHeight width r,
      (AstTensor
         AstMethodLet
         FullSpan
         (TKS
            ((':)
               @Nat
               batch_size
               ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
            r),
       AstTensor
         AstMethodLet
         FullSpan
         (TKS
            ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))))
  (TKScalar r)
artRaw = IncomingCotangentHandling
-> ((ADRnnMnistParametersShaped
       (AstTensor AstMethodLet FullSpan) SizeMnistHeight width r,
     (AstTensor
        AstMethodLet
        FullSpan
        (TKS
           ((':)
              @Nat
              batch_size
              ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
           r),
      AstTensor
        AstMethodLet
        FullSpan
        (TKS
           ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
    -> AstTensor AstMethodLet FullSpan (TKScalar r))
-> FullShapeTK
     (X (ADRnnMnistParametersShaped
           (AstTensor AstMethodLet FullSpan) SizeMnistHeight width r,
         (AstTensor
            AstMethodLet
            FullSpan
            (TKS
               ((':)
                  @Nat
                  batch_size
                  ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
               r),
          AstTensor
            AstMethodLet
            FullSpan
            (TKS
               ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))))
-> AstArtifactRev
     (X (ADRnnMnistParametersShaped
           (AstTensor AstMethodLet FullSpan) SizeMnistHeight width r,
         (AstTensor
            AstMethodLet
            FullSpan
            (TKS
               ((':)
                  @Nat
                  batch_size
                  ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
               r),
          AstTensor
            AstMethodLet
            FullSpan
            (TKS
               ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))))
     (TKScalar r)
forall src (ztgt :: TK) tgt.
(AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
 (tgt :: Type) ~ (AstTensor AstMethodLet FullSpan ztgt :: Type)) =>
IncomingCotangentHandling
-> (src -> tgt)
-> FullShapeTK (X src)
-> AstArtifactRev (X src) ztgt
revArtifactAdapt IncomingCotangentHandling
IgnoreIncomingCotangent
                                     (ADRnnMnistParametersShaped
   (AstTensor AstMethodLet FullSpan) SizeMnistHeight width r,
 (AstTensor
    AstMethodLet
    FullSpan
    (TKS
       ((':)
          @Nat
          batch_size
          ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
       r),
  AstTensor
    AstMethodLet
    FullSpan
    (TKS
       ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f (FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> FullShapeTK
     (TKProduct
        (TKS
           ((':)
              @Nat
              batch_size
              ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
           r)
        (TKS
           ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
-> FullShapeTK
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS
              ((':)
                 @Nat
                 batch_size
                 ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
              r)
           (TKS
              ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
ftk FullShapeTK
  (TKProduct
     (TKS
        ((':)
           @Nat
           batch_size
           ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
        r)
     (TKS
        ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
ftkData)
           art :: AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                    (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
              (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS
           ((':)
              @Nat
              batch_size
              ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
           r)
        (TKS
           ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
  (TKScalar r)
art = AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                    (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
              (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS
           ((':)
              @Nat
              batch_size
              ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
           r)
        (TKS
           ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
  (TKScalar r)
-> AstArtifactRev
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS
              ((':)
                 @Nat
                 batch_size
                 ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
              r)
           (TKS
              ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
     (TKScalar r)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifactGradient AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                    (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
              (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS
           ((':)
              @Nat
              batch_size
              ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
           r)
        (TKS
           ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
  (TKScalar r)
AstArtifactRev
  (X (ADRnnMnistParametersShaped
        (AstTensor AstMethodLet FullSpan) SizeMnistHeight width r,
      (AstTensor
         AstMethodLet
         FullSpan
         (TKS
            ((':)
               @Nat
               batch_size
               ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
            r),
       AstTensor
         AstMethodLet
         FullSpan
         (TKS
            ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))))
  (TKScalar r)
artRaw
           go :: [MnistDataBatchS batch_size r]
              -> ( Concrete (XParams width r)
                 , StateAdam (XParams width r) )
              -> ( Concrete (XParams width r)
                 , StateAdam (XParams width r) )
           go :: [MnistDataBatchS batch_size r]
-> (Concrete (XParams width r), StateAdam (XParams width r))
-> (Concrete (XParams width r), StateAdam (XParams width r))
go [] (Concrete (XParams width r)
parameters, StateAdam (XParams width r)
stateAdam) = (Concrete (XParams width r)
parameters, StateAdam (XParams width r)
stateAdam)
           go ((Shaped
  ((':)
     @Nat
     batch_size
     ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
  r
glyph, Shaped
  ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
label) : [MnistDataBatchS batch_size r]
rest) (!Concrete (XParams width r)
parameters, !StateAdam (XParams width r)
stateAdam) =
             let parametersAndInput :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                    (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
              (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS
           ((':)
              @Nat
              batch_size
              ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
           r)
        (TKS
           ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
parametersAndInput =
                   Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> Concrete
     (TKProduct
        (TKS
           ((':)
              @Nat
              batch_size
              ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
           r)
        (TKS
           ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS
              ((':)
                 @Nat
                 batch_size
                 ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
              r)
           (TKS
              ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
parameters (Concrete
  (TKS
     ((':)
        @Nat
        batch_size
        ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
     r)
-> Concrete
     (TKS
        ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
-> Concrete
     (TKProduct
        (TKS
           ((':)
              @Nat
              batch_size
              ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
           r)
        (TKS
           ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (Shaped
  ((':)
     @Nat
     batch_size
     ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
  r
-> Concrete
     (TKS
        ((':)
           @Nat
           batch_size
           ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
        r)
forall r (target :: TK -> Type) (sh :: [Nat]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
  ((':)
     @Nat
     batch_size
     ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
  r
glyph) (Shaped
  ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
-> Concrete
     (TKS
        ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
forall r (target :: TK -> Type) (sh :: [Nat]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
  ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
label))
                 gradient :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
gradient = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                    (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
              (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS
           ((':)
              @Nat
              batch_size
              ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
           r)
        (TKS
           ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                    (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
              (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
forall (x :: TK) (z :: TK). Concrete (TKProduct x z) -> Concrete x
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                     (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
            (TKProduct
               (TKProduct
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
         (TKProduct
            (TKS2
               ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
               (TKScalar r))
            (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS
            ((':)
               @Nat
               batch_size
               ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
            r)
         (TKS
            ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                     (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
            (TKProduct
               (TKProduct
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
         (TKProduct
            (TKS2
               ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
               (TKScalar r))
            (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS
              ((':)
                 @Nat
                 batch_size
                 ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
              r)
           (TKS
              ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                    (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
              (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
forall a b. (a -> b) -> a -> b
$ (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                     (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
            (TKProduct
               (TKProduct
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
         (TKProduct
            (TKS2
               ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
               (TKScalar r))
            (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS
            ((':)
               @Nat
               batch_size
               ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
            r)
         (TKS
            ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))),
 Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS
              ((':)
                 @Nat
                 batch_size
                 ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
              r)
           (TKS
              ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
forall a b. (a, b) -> a
fst
                            ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct
                (TKProduct
                   (TKS2
                      ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                      (TKScalar r))
                   (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
             (TKProduct
                (TKProduct
                   (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                   (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
          (TKProduct
             (TKS2
                ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                (TKScalar r))
             (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
       (TKProduct
          (TKS
             ((':)
                @Nat
                batch_size
                ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
             r)
          (TKS
             ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))),
  Concrete (TKScalar r))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct
                     (TKS2
                        ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                        (TKScalar r))
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKProduct
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKS2
                  ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
         (TKProduct
            (TKS
               ((':)
                  @Nat
                  batch_size
                  ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
               r)
            (TKS
               ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct
                     (TKS2
                        ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                        (TKScalar r))
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKProduct
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKS2
                  ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
         (TKProduct
            (TKS
               ((':)
                  @Nat
                  batch_size
                  ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
               r)
            (TKS
               ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))),
    Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS
              ((':)
                 @Nat
                 batch_size
                 ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
              r)
           (TKS
              ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
forall a b. (a -> b) -> a -> b
$ AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                    (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
              (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS
           ((':)
              @Nat
              batch_size
              ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
           r)
        (TKS
           ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
  (TKScalar r)
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS
              ((':)
                 @Nat
                 batch_size
                 ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
              r)
           (TKS
              ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
-> Maybe (Concrete (ADTensorKind (TKScalar r)))
-> (Concrete
      (ADTensorKind
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct
                     (TKProduct
                        (TKS2
                           ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                           (TKScalar r))
                        (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                     (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
                  (TKProduct
                     (TKProduct
                        (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                        (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                     (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
               (TKProduct
                  (TKS2
                     ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                     (TKScalar r))
                  (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKS
                  ((':)
                     @Nat
                     batch_size
                     ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
                  r)
               (TKS
                  ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))),
    Concrete (TKScalar r))
forall (x :: TK) (z :: TK).
AstArtifactRev x z
-> Concrete x
-> Maybe (Concrete (ADTensorKind z))
-> (Concrete (ADTensorKind x), Concrete z)
revInterpretArtifact
                                AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                    (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
              (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS
           ((':)
              @Nat
              batch_size
              ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
           r)
        (TKS
           ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
  (TKScalar r)
art Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                    (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
              (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS
           ((':)
              @Nat
              batch_size
              ((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
           r)
        (TKS
           ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
parametersAndInput Maybe (Concrete (ADTensorKind (TKScalar r)))
Maybe (Concrete (TKScalar r))
forall a. Maybe a
Nothing
             in [MnistDataBatchS batch_size r]
-> (Concrete (XParams width r), StateAdam (XParams width r))
-> (Concrete (XParams width r), StateAdam (XParams width r))
go [MnistDataBatchS batch_size r]
rest (forall (y :: TK).
ArgsAdam
-> StateAdam y
-> SingletonTK y
-> Concrete y
-> Concrete (ADTensorKind y)
-> (Concrete y, StateAdam y)
updateWithGradientAdam
                           @(XParams width r)
                           ArgsAdam
defaultArgsAdam StateAdam (XParams width r)
stateAdam SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
SingletonTK (XParams width r)
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete (XParams width r)
parameters
                           Concrete (ADTensorKind (XParams width r))
Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
gradient)
           runBatch :: ( Concrete (XParams width r)
                       , StateAdam (XParams width r) )
                    -> (Int, [MnistDataS r])
                    -> IO ( Concrete (XParams width r)
                          , StateAdam (XParams width r) )
           runBatch :: (Concrete (XParams width r), StateAdam (XParams width r))
-> (Int, [MnistDataS r])
-> IO (Concrete (XParams width r), StateAdam (XParams width r))
runBatch (!Concrete (XParams width r)
parameters, !StateAdam (XParams width r)
stateAdam) (Int
k, [MnistDataS r]
chunk) = do
             let chunkS :: [MnistDataBatchS batch_size r]
chunkS = ([MnistDataS r] -> MnistDataBatchS batch_size r)
-> [[MnistDataS r]] -> [MnistDataBatchS batch_size r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataS r] -> MnistDataBatchS batch_size r
forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS
                          ([[MnistDataS r]] -> [MnistDataBatchS batch_size r])
-> [[MnistDataS r]] -> [MnistDataBatchS batch_size r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataS r] -> Bool) -> [[MnistDataS r]] -> [[MnistDataS r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataS r]
ch -> [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSize)
                          ([[MnistDataS r]] -> [[MnistDataS r]])
-> [[MnistDataS r]] -> [[MnistDataS r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataS r]
chunk
                 res :: (Concrete (XParams width r), StateAdam (XParams width r))
res@(Concrete (XParams width r)
parameters2, StateAdam (XParams width r)
_) = [MnistDataBatchS batch_size r]
-> (Concrete (XParams width r), StateAdam (XParams width r))
-> (Concrete (XParams width r), StateAdam (XParams width r))
go [MnistDataBatchS batch_size r]
chunkS (Concrete (XParams width r)
parameters, StateAdam (XParams width r)
stateAdam)
                 trainScore :: r
trainScore = Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat ([MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk) ((forall (n :: Nat). KnownNat n => SNat n -> r) -> r)
-> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
forall a b. (a -> b) -> a -> b
$ \(SNat @len) ->
                   forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @len ([MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
chunk) Concrete (XParams width r)
parameters2
                 testScore :: r
testScore = forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @lenTestData MnistDataBatchS n r
testDataS Concrete (XParams width r)
parameters2
                 lenChunk :: Int
lenChunk = [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
                        String
prefix Int
k Int
lenChunk
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error:   %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
             (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
res
       let runEpoch :: Int
                    -> ( Concrete (XParams width r)
                       , StateAdam (XParams width r) )
                    -> IO (Concrete (XParams width r))
           runEpoch :: Int
-> (Concrete (XParams width r), StateAdam (XParams width r))
-> IO (Concrete (XParams width r))
runEpoch Int
n (Concrete (XParams width r)
params2, StateAdam (XParams width r)
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
params2
           runEpoch Int
n paramsStateAdam :: (Concrete (XParams width r), StateAdam (XParams width r))
paramsStateAdam@(!Concrete (XParams width r)
_, !StateAdam (XParams width r)
_) = do
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
             let trainDataShuffled :: [MnistDataS r]
trainDataShuffled = StdGen -> [MnistDataS r] -> [MnistDataS r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataS r]
trainData
                 chunks :: [(Int, [MnistDataS r])]
chunks = Int -> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                          ([(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])])
-> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..]
                          ([[MnistDataS r]] -> [(Int, [MnistDataS r])])
-> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataS r]
trainDataShuffled
             res <- ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct
                (TKS2
                   ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                   (TKScalar r))
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
             (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
          (TKProduct
             (TKProduct
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
             (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
       (TKProduct
          (TKS2
             ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
             (TKScalar r))
          (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
  StateAdam
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct
                (TKS2
                   ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                   (TKScalar r))
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
             (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
          (TKProduct
             (TKProduct
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
             (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
       (TKProduct
          (TKS2
             ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
             (TKScalar r))
          (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
 -> (Int, [MnistDataS r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct
                     (TKS2
                        ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                        (TKScalar r))
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKProduct
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKS2
                  ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
       StateAdam
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct
                     (TKS2
                        ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                        (TKScalar r))
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKProduct
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                     (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKS2
                  ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                     (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
            (TKProduct
               (TKProduct
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
         (TKProduct
            (TKS2
               ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
               (TKScalar r))
            (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
    StateAdam
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                     (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
            (TKProduct
               (TKProduct
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
               (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
         (TKProduct
            (TKS2
               ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
               (TKScalar r))
            (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> [(Int, [MnistDataS r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> (Int, [MnistDataS r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                       (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
-> (Int, [MnistDataS r])
-> IO (Concrete (XParams width r), StateAdam (XParams width r))
runBatch (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                  (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
            (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
paramsStateAdam [(Int, [MnistDataS r])]
chunks
             runEpoch (succ n) res
       res <- Int
-> (Concrete (XParams width r), StateAdam (XParams width r))
-> IO (Concrete (XParams width r))
runEpoch Int
1 (Concrete (XParams width r)
targetInit, FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> StateAdam
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                    (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
              (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
forall (y :: TK). FullShapeTK y -> StateAdam y
initialStateAdam FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
ftk)
       let testErrorFinal = r
1 r -> r -> r
forall a. Num a => a -> a -> a
- forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @lenTestData MnistDataBatchS n r
testDataS Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
                 (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
           (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
res
       assertEqualUpToEpsilon 1e-1 expected testErrorFinal

-- Ask GHC for a copy of 'mnistTestCaseRNNSO' specialised to scalar type
-- 'Double'; the type-level sizes 'width' and 'batch_size' stay polymorphic.
-- (The 'Float' instantiations are not specialised by this pragma.)
{-# SPECIALIZE mnistTestCaseRNNSO
      :: String -> Int -> Int
      -> SNat width -> SNat batch_size
      -> Int -> Double
      -> TestTree #-}

-- | The \"RNNS Once\" test group: 'mnistTestCaseRNNSO' instantiated at four
-- configurations, mixing 'Double' and 'Float' as the scalar type.
--
-- Each case passes, in the order of 'mnistTestCaseRNNSO''s signature:
-- the test name, two 'Int' run-size parameters (judging by the test names,
-- presumably the number of epochs and of batches -- TODO confirm), the
-- hidden width and the batch size as type-level 'SNat's, one more 'Int'
-- (presumably the amount of test data -- TODO confirm), and the expected
-- final value whose type fixes the scalar type @r@.
tensorADValMnistTestsRNNSO :: TestTree
tensorADValMnistTestsRNNSO = testGroup "RNNS Once MNIST tests"
  [ -- Width 128, batch size 150.
    mnistTestCaseRNNSO "RNNSO 1 epoch, 1 batch" 1 1 (SNat @128) (SNat @150)
                       5000 (0.6026 :: Double)
    -- Small \"artificial\" configurations, one per scalar type.
  , mnistTestCaseRNNSO "RNNSO artificial 1 2 3 4 5" 2 3 (SNat @4) (SNat @5)
                       50 (0.8933333 :: Float)
  , mnistTestCaseRNNSO "RNNSO artificial 5 4 3 2 1" 5 4 (SNat @3) (SNat @2)
                       49 (0.9336734693877551 :: Double)
    -- Second run-size argument is 0 (\"0 batch\" per the test name).
  , mnistTestCaseRNNSO "RNNSO 1 epoch, 0 batch" 1 0 (SNat @128) (SNat @150)
                       50 (1.0 :: Float)
  ]