{-# LANGUAGE OverloadedLists #-}
-- | Tests of "MnistRnnRanked2" recurrent neural networks using a few different
-- optimization pipelines.
--
-- Not LSTM.
-- Doesn't train without Adam, regardless of whether mini-batches are used. It
-- does train with Adam, but only after very carefully tweaking initialization.
-- Training is extremely sensitive to the initial parameters, more than to
-- anything else. The gradient probably vanishes when the parameters are
-- initialized from a probability distribution that doesn't have the right
-- variance. See
-- https://stats.stackexchange.com/questions/301285/what-is-vanishing-gradient.
-- Regularization/normalization might help as well.
module TestMnistRNNR
  ( testTrees
  ) where

import Prelude

import Control.Monad (foldM, unless)
import System.IO (hPutStrLn, stderr)
import System.Random
import Test.Tasty
import Test.Tasty.HUnit hiding (assert)
import Text.Printf

import Data.Array.Nested.Ranked.Shape

import HordeAd
import HordeAd.Core.Adaptor
import HordeAd.Core.AstEnv
import HordeAd.Core.AstFreshId
import HordeAd.Core.AstInterpret

import EqEpsilon

import MnistData
import MnistRnnRanked2 (ADRnnMnistParameters, ADRnnMnistParametersShaped)
import MnistRnnRanked2 qualified

-- TODO: optimize enough that it can run for one full epoch in reasonable time
-- and then verify it trains down to ~20% validation error in a short enough
-- time to include such a training run in tests.

-- | All test groups exported by this module.
--
-- The @RA@ group exercises differentiation directly via the ADVal instance
-- (see the comment on 'mnistTestCaseRNNRA').
-- NOTE(review): the @RI@ and @RO@ groups presumably cover the other
-- optimization pipelines mentioned in the module header; confirm against
-- their definitions.
testTrees :: [TestTree]
testTrees = [ tensorADValMnistTestsRNNRA
            , tensorADValMnistTestsRNNRI
            , tensorADValMnistTestsRNNRO
            ]

-- POPL differentiation, straight via the ADVal instance of RankedTensor,
-- which side-steps vectorization.
mnistTestCaseRNNRA
  :: forall r.
     (Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r)
  => String
  -> Int -> Int -> Int -> Int -> Int -> r
  -> TestTree
mnistTestCaseRNNRA :: forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r) =>
String -> Int -> Int -> Int -> Int -> Int -> r -> TestTree
mnistTestCaseRNNRA String
prefix Int
epochs Int
maxBatches Int
width Int
miniBatchSize Int
totalBatchSize
                   r
expected =
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
width ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat @width) ->
  let targetInit :: NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit =
        Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
 -> NoShape
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct
                     (TKS2
                        ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                     (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKProduct
                     (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                     (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKS2
                  ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a b. (a -> b) -> a -> b
$ (Concrete (X (ADRnnMnistParametersShaped Concrete n r)), StdGen)
-> Concrete (X (ADRnnMnistParametersShaped Concrete n r))
forall a b. (a, b) -> a
fst
        ((Concrete (X (ADRnnMnistParametersShaped Concrete n r)), StdGen)
 -> Concrete (X (ADRnnMnistParametersShaped Concrete n r)))
-> (Concrete (X (ADRnnMnistParametersShaped Concrete n r)), StdGen)
-> Concrete (X (ADRnnMnistParametersShaped Concrete n r))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue @(Concrete (X (ADRnnMnistParametersShaped
                                       Concrete width r)))
                      Double
0.23 (Int -> StdGen
mkStdGen Int
44)
      name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
             String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
                        , Int -> String
forall a. Show a => a -> String
show Int
width, Int -> String
forall a. Show a => a -> String
show Int
miniBatchSize
                        , Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK (X (ADRnnMnistParameters Concrete r)) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
                          (SingletonTK (X (ADRnnMnistParameters Concrete r)) -> Int)
-> SingletonTK (X (ADRnnMnistParameters Concrete r)) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(X (ADRnnMnistParameters Concrete r))
                        , Int -> String
forall a. Show a => a -> String
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r)))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit) ]
      ftest :: Int -> MnistDataBatchR r
            -> Concrete (X (ADRnnMnistParameters Concrete r))
            -> r
      ftest :: Int
-> MnistDataBatchR r
-> Concrete (X (ADRnnMnistParameters Concrete r))
-> r
ftest Int
batch_size MnistDataBatchR r
mnistData Concrete (X (ADRnnMnistParameters Concrete r))
pars =
        Int -> MnistDataBatchR r -> ADRnnMnistParameters Concrete r -> r
forall (target :: TK -> Type) r.
((target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
 GoodScalar r, Differentiable r) =>
Int -> MnistDataBatchR r -> ADRnnMnistParameters target r -> r
MnistRnnRanked2.rnnMnistTestR
          Int
batch_size MnistDataBatchR r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete (X (ADRnnMnistParameters Concrete r))
pars)
  in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
       Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
         String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
                String
prefix Int
epochs Int
maxBatches
       trainData <- (MnistData r -> MnistDataR r) -> [MnistData r] -> [MnistDataR r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataR r
forall r. PrimElt r => MnistData r -> MnistDataR r
mkMnistDataR
                    ([MnistData r] -> [MnistDataR r])
-> IO [MnistData r] -> IO [MnistDataR r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
       testData <- map mkMnistDataR . take (totalBatchSize * maxBatches)
                   <$> loadMnistData testGlyphsPath testLabelsPath
       let testDataR = [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
testData
           f :: MnistDataBatchR r
             -> ADVal Concrete (X (ADRnnMnistParameters Concrete r))
             -> ADVal Concrete (TKScalar r)
           f (Ranked 3 r
glyphR, Ranked 2 r
labelR) ADVal Concrete (X (ADRnnMnistParameters Concrete r))
adinputs =
             Int
-> (PrimalOf (ADVal Concrete) (TKR 3 r),
    PrimalOf (ADVal Concrete) (TKR2 2 (TKScalar r)))
-> ADRnnMnistParameters (ADVal Concrete) r
-> ADVal Concrete (TKScalar r)
forall (target :: TK -> Type) r.
(ADReady target, ADReady (PrimalOf target), GoodScalar r,
 Differentiable r) =>
Int
-> (PrimalOf target (TKR 3 r), PrimalOf target (TKR 2 r))
-> ADRnnMnistParameters target r
-> target (TKScalar r)
MnistRnnRanked2.rnnMnistLossFusedR
               Int
miniBatchSize (Ranked 3 r -> Concrete (TKR 3 r)
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 3 r
glyphR, Ranked 2 r -> Concrete (TKR2 2 (TKScalar r))
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 2 r
labelR)
               (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @(ADVal Concrete) ADVal Concrete (X (ADRnnMnistParameters (ADVal Concrete) r))
ADVal Concrete (X (ADRnnMnistParameters Concrete r))
adinputs)
           runBatch :: ( Concrete (X (ADRnnMnistParameters Concrete r))
                       , StateAdam (X (ADRnnMnistParameters Concrete r)) )
                    -> (Int, [MnistDataR r])
                    -> IO ( Concrete (X (ADRnnMnistParameters Concrete r))
                          , StateAdam (X (ADRnnMnistParameters Concrete r)) )
           runBatch (!Concrete (X (ADRnnMnistParameters Concrete r))
parameters, !StateAdam (X (ADRnnMnistParameters Concrete r))
stateAdam) (Int
k, [MnistDataR r]
chunk) = do
             let chunkR :: [MnistDataBatchR r]
chunkR = ([MnistDataR r] -> MnistDataBatchR r)
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR
                          ([[MnistDataR r]] -> [MnistDataBatchR r])
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataR r] -> Bool) -> [[MnistDataR r]] -> [[MnistDataR r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataR r]
ch -> [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSize)
                          ([[MnistDataR r]] -> [[MnistDataR r]])
-> [[MnistDataR r]] -> [[MnistDataR r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataR r]
chunk
                 res :: (Concrete (X (ADRnnMnistParameters Concrete r)),
 StateAdam (X (ADRnnMnistParameters Concrete r)))
res@(Concrete (X (ADRnnMnistParameters Concrete r))
parameters2, StateAdam (X (ADRnnMnistParameters Concrete r))
_) =
                   forall a (x :: TK) (z :: TK).
KnownSTK x =>
(a -> ADVal Concrete x -> ADVal Concrete z)
-> [a] -> Concrete x -> StateAdam x -> (Concrete x, StateAdam x)
sgdAdam @(MnistDataBatchR r)
                               @(X (ADRnnMnistParameters Concrete r))
                               MnistDataBatchR r
-> ADVal Concrete (X (ADRnnMnistParameters Concrete r))
-> ADVal Concrete (TKScalar r)
f [MnistDataBatchR r]
chunkR Concrete (X (ADRnnMnistParameters Concrete r))
parameters StateAdam (X (ADRnnMnistParameters Concrete r))
stateAdam
                 trainScore :: r
trainScore =
                   Int
-> MnistDataBatchR r
-> Concrete (X (ADRnnMnistParameters Concrete r))
-> r
ftest ([MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk) ([MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
chunk) Concrete (X (ADRnnMnistParameters Concrete r))
parameters2
                 testScore :: r
testScore =
                   Int
-> MnistDataBatchR r
-> Concrete (X (ADRnnMnistParameters Concrete r))
-> r
ftest ((Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) Int -> Int -> Int
forall a. Ord a => a -> a -> a
`min` Int
10000)
                         MnistDataBatchR r
testDataR Concrete (X (ADRnnMnistParameters Concrete r))
parameters2
                 lenChunk :: Int
lenChunk = [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
                        String
prefix Int
k Int
lenChunk
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error:   %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
             (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(Concrete (X (ADRnnMnistParameters Concrete r)),
 StateAdam (X (ADRnnMnistParameters Concrete r)))
res
       let runEpoch :: Int
                    -> ( Concrete (X (ADRnnMnistParameters Concrete r))
                       , StateAdam (X (ADRnnMnistParameters Concrete r)) )
                    -> IO (Concrete (X (ADRnnMnistParameters Concrete r)))
           runEpoch Int
n (Concrete (X (ADRnnMnistParameters Concrete r))
params2, StateAdam (X (ADRnnMnistParameters Concrete r))
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (X (ADRnnMnistParameters Concrete r))
params2
           runEpoch Int
n paramsStateAdam :: (Concrete (X (ADRnnMnistParameters Concrete r)),
 StateAdam (X (ADRnnMnistParameters Concrete r)))
paramsStateAdam@(!Concrete (X (ADRnnMnistParameters Concrete r))
_, !StateAdam (X (ADRnnMnistParameters Concrete r))
_) = do
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
             let trainDataShuffled :: [MnistDataR r]
trainDataShuffled = StdGen -> [MnistDataR r] -> [MnistDataR r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataR r]
trainData
                 chunks :: [(Int, [MnistDataR r])]
chunks = Int -> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                          ([(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])])
-> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
Item [Int]
1 ..]
                          ([[MnistDataR r]] -> [(Int, [MnistDataR r])])
-> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataR r]
trainDataShuffled
             res <- ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
             (TKR2 1 (TKScalar r)))
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
             (TKR2 1 (TKScalar r))))
       (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
  StateAdam
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
             (TKR2 1 (TKScalar r)))
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
             (TKR2 1 (TKScalar r))))
       (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
 -> (Int, [MnistDataR r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                  (TKR2 1 (TKScalar r)))
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                  (TKR2 1 (TKScalar r))))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
       StateAdam
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                  (TKR2 1 (TKScalar r)))
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                  (TKR2 1 (TKScalar r))))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r)))
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
    StateAdam
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r)))
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> [(Int, [MnistDataR r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> (Int, [MnistDataR r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(Concrete (X (ADRnnMnistParameters Concrete r)),
 StateAdam (X (ADRnnMnistParameters Concrete r)))
-> (Int, [MnistDataR r])
-> IO
     (Concrete (X (ADRnnMnistParameters Concrete r)),
      StateAdam (X (ADRnnMnistParameters Concrete r)))
runBatch (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(Concrete (X (ADRnnMnistParameters Concrete r)),
 StateAdam (X (ADRnnMnistParameters Concrete r)))
paramsStateAdam [(Int, [MnistDataR r])]
chunks
             runEpoch (succ n) res
           ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete
                      (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(X (ADRnnMnistParameters Concrete r)))
                      Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit
       res <- runEpoch 1 (targetInit, initialStateAdam ftk)
       let testErrorFinal =
             r
1 r -> r -> r
forall a. Num a => a -> a -> a
- Int
-> MnistDataBatchR r
-> Concrete (X (ADRnnMnistParameters Concrete r))
-> r
ftest ((Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) Int -> Int -> Int
forall a. Ord a => a -> a -> a
`min` Int
10000)
                       MnistDataBatchR r
testDataR Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (X (ADRnnMnistParameters Concrete r))
res
       testErrorFinal @?~ expected

-- Ask GHC for a monomorphic copy of 'mnistTestCaseRNNRA' at @r ~ Double@,
-- the scalar type used by some of the cases in 'tensorADValMnistTestsRNNRA'.
-- NOTE(review): the Float-typed cases in that list have no explicit
-- SPECIALIZE pragma here; confirm whether one is wanted.
{-# SPECIALIZE mnistTestCaseRNNRA
  :: String
  -> Int -> Int -> Int -> Int -> Int -> Double
  -> TestTree #-}

-- | Test group for 'mnistTestCaseRNNRA': RNN training differentiated
-- directly through the ADVal instance (no Ast, no vectorization).
--
-- Each case passes a name prefix, five 'Int' settings and the expected
-- final test error. The test error is computed as one minus the 'ftest'
-- score, then compared with '@?~'. The 'Int' settings are presumably
-- epochs, max batches, width, mini-batch size and total batch size, in the
-- same order as the parameters of 'mnistTestCaseRNNRI' (TODO confirm).
-- The cases alternate between 'Double' and 'Float' scalars, so both
-- instantiations are exercised.
tensorADValMnistTestsRNNRA :: TestTree
tensorADValMnistTestsRNNRA = testGroup "RNNR ADVal MNIST tests"
  [ mnistTestCaseRNNRA "RNNRA 1 epoch, 1 batch" 1 1 128 150 5000
                       (0.6026 :: Double)
  , mnistTestCaseRNNRA "RNNRA artificial 1 2 3 4 5" 2 3 4 5 50
                       (0.8933333 :: Float)
  , mnistTestCaseRNNRA "RNNRA artificial 5 4 3 2 1" 5 4 3 2 49
                       (0.8622448979591837 :: Double)
  , mnistTestCaseRNNRA "RNNRA 1 epoch, 0 batch" 1 0 128 150 50
                       (1.0 :: Float)
  ]

-- POPL differentiation, with Ast term defined and vectorized only once,
-- but differentiated anew in each gradient descent iteration.
mnistTestCaseRNNRI
  :: forall r.
     (Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r)
  => String
  -> Int -> Int -> Int -> Int -> Int -> r
  -> TestTree
mnistTestCaseRNNRI :: forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r) =>
String -> Int -> Int -> Int -> Int -> Int -> r -> TestTree
mnistTestCaseRNNRI String
prefix Int
epochs Int
maxBatches Int
width Int
miniBatchSize Int
totalBatchSize
                   r
expected =
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
width ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat @width) ->
  let targetInit :: NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit =
        Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
 -> NoShape
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct
                     (TKS2
                        ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                     (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKProduct
                     (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                     (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKS2
                  ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a b. (a -> b) -> a -> b
$ (Concrete (X (ADRnnMnistParametersShaped Concrete n r)), StdGen)
-> Concrete (X (ADRnnMnistParametersShaped Concrete n r))
forall a b. (a, b) -> a
fst
        ((Concrete (X (ADRnnMnistParametersShaped Concrete n r)), StdGen)
 -> Concrete (X (ADRnnMnistParametersShaped Concrete n r)))
-> (Concrete (X (ADRnnMnistParametersShaped Concrete n r)), StdGen)
-> Concrete (X (ADRnnMnistParametersShaped Concrete n r))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue @(Concrete (X (ADRnnMnistParametersShaped
                                       Concrete width r)))
                      Double
0.23 (Int -> StdGen
mkStdGen Int
44)
      name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
             String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
                        , Int -> String
forall a. Show a => a -> String
show Int
width, Int -> String
forall a. Show a => a -> String
show Int
miniBatchSize
                        , Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK (X (ADRnnMnistParameters Concrete r)) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
                          (SingletonTK (X (ADRnnMnistParameters Concrete r)) -> Int)
-> SingletonTK (X (ADRnnMnistParameters Concrete r)) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(X (ADRnnMnistParameters Concrete r))
                        , Int -> String
forall a. Show a => a -> String
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r)))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit) ]
      ftest :: Int -> MnistDataBatchR r
            -> Concrete (X (ADRnnMnistParameters Concrete r))
            -> r
      ftest :: Int
-> MnistDataBatchR r
-> Concrete (X (ADRnnMnistParameters Concrete r))
-> r
ftest Int
batch_size MnistDataBatchR r
mnistData Concrete (X (ADRnnMnistParameters Concrete r))
pars =
        Int -> MnistDataBatchR r -> ADRnnMnistParameters Concrete r -> r
forall (target :: TK -> Type) r.
((target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
 GoodScalar r, Differentiable r) =>
Int -> MnistDataBatchR r -> ADRnnMnistParameters target r -> r
MnistRnnRanked2.rnnMnistTestR
          Int
batch_size MnistDataBatchR r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete (X (ADRnnMnistParameters Concrete r))
pars)
  in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
       Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
         String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
                String
prefix Int
epochs Int
maxBatches
       trainData <- (MnistData r -> MnistDataR r) -> [MnistData r] -> [MnistDataR r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataR r
forall r. PrimElt r => MnistData r -> MnistDataR r
mkMnistDataR
                    ([MnistData r] -> [MnistDataR r])
-> IO [MnistData r] -> IO [MnistDataR r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
       testData <- map mkMnistDataR . take (totalBatchSize * maxBatches)
                   <$> loadMnistData testGlyphsPath testLabelsPath
       let testDataR = [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
testData
           ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete
                      (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(X (ADRnnMnistParameters Concrete r)))
                      Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit
       (_, _, var, varAst) <- funToAstRevIO ftk
       (varGlyph, astGlyph) <-
         funToAstIO (FTKR (miniBatchSize
                           :$: sizeMnistHeightInt
                           :$: sizeMnistWidthInt
                           :$: ZSR) FTKScalar) id
       (varLabel, astLabel) <-
         funToAstIO (FTKR (miniBatchSize
                           :$: sizeMnistLabelInt
                           :$: ZSR) FTKScalar) id
       let ast :: AstTensor AstMethodLet FullSpan (TKScalar r)
           ast = AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (z :: TK) (s :: AstSpanType).
AstSpan s =>
AstTensor AstMethodLet s z -> AstTensor AstMethodLet s z
simplifyInline
                 (AstTensor AstMethodLet FullSpan (TKScalar r)
 -> AstTensor AstMethodLet FullSpan (TKScalar r))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall a b. (a -> b) -> a -> b
$ Int
-> (PrimalOf
      (AstTensor AstMethodLet FullSpan) (TKR2 3 (TKScalar r)),
    PrimalOf (AstTensor AstMethodLet FullSpan) (TKR2 2 (TKScalar r)))
-> ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (target :: TK -> Type) r.
(ADReady target, ADReady (PrimalOf target), GoodScalar r,
 Differentiable r) =>
Int
-> (PrimalOf target (TKR 3 r), PrimalOf target (TKR 2 r))
-> ADRnnMnistParameters target r
-> target (TKScalar r)
MnistRnnRanked2.rnnMnistLossFusedR
                     Int
miniBatchSize (AstTensor AstMethodLet PrimalSpan (TKR2 3 (TKScalar r))
PrimalOf (AstTensor AstMethodLet FullSpan) (TKR2 3 (TKScalar r))
astGlyph, AstTensor AstMethodLet PrimalSpan (TKR2 2 (TKScalar r))
PrimalOf (AstTensor AstMethodLet FullSpan) (TKR2 2 (TKScalar r))
astLabel)
                     (AstTensor
  AstMethodLet
  FullSpan
  (X (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r))
-> ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r
forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget AstTensor
  AstMethodLet
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
AstTensor
  AstMethodLet
  FullSpan
  (X (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r))
varAst)
           f :: MnistDataBatchR r
              -> ADVal Concrete (X (ADRnnMnistParameters Concrete r))
              -> ADVal Concrete (TKScalar r)
           f (Ranked 3 r
glyph, Ranked 2 r
label) ADVal Concrete (X (ADRnnMnistParameters Concrete r))
varInputs =
             let env :: AstEnv (ADVal Concrete)
env = AstVarName
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r)))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
var ADVal
  Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
ADVal Concrete (X (ADRnnMnistParameters Concrete r))
varInputs AstEnv (ADVal Concrete)
forall (target :: TK -> Type). AstEnv target
emptyEnv
                 envMnist :: AstEnv (ADVal Concrete)
envMnist = AstVarName PrimalSpan (TKR2 3 (TKScalar r))
-> ADVal Concrete (TKR2 3 (TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName PrimalSpan (TKR2 3 (TKScalar r))
varGlyph (Ranked 3 r -> ADVal Concrete (TKR2 3 (TKScalar r))
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 3 r
glyph)
                            (AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete))
-> AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete)
forall a b. (a -> b) -> a -> b
$ AstVarName PrimalSpan (TKR2 2 (TKScalar r))
-> ADVal Concrete (TKR2 2 (TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName PrimalSpan (TKR2 2 (TKScalar r))
varLabel (Ranked 2 r -> ADVal Concrete (TKR2 2 (TKScalar r))
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 2 r
label) AstEnv (ADVal Concrete)
env
             in AstEnv (ADVal Concrete)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> ADVal Concrete (TKScalar r)
forall (target :: TK -> Type) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull AstEnv (ADVal Concrete)
envMnist AstTensor AstMethodLet FullSpan (TKScalar r)
ast
           runBatch :: ( Concrete (X (ADRnnMnistParameters Concrete r))
                       , StateAdam (X (ADRnnMnistParameters Concrete r)) )
                    -> (Int, [MnistDataR r])
                    -> IO ( Concrete (X (ADRnnMnistParameters Concrete r))
                          , StateAdam (X (ADRnnMnistParameters Concrete r)) )
           runBatch (!Concrete (X (ADRnnMnistParameters Concrete r))
parameters, !StateAdam (X (ADRnnMnistParameters Concrete r))
stateAdam) (Int
k, [MnistDataR r]
chunk) = do
             let chunkR :: [MnistDataBatchR r]
chunkR = ([MnistDataR r] -> MnistDataBatchR r)
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR
                          ([[MnistDataR r]] -> [MnistDataBatchR r])
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataR r] -> Bool) -> [[MnistDataR r]] -> [[MnistDataR r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataR r]
ch -> [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSize)
                          ([[MnistDataR r]] -> [[MnistDataR r]])
-> [[MnistDataR r]] -> [[MnistDataR r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataR r]
chunk
                 res :: (Concrete (X (ADRnnMnistParameters Concrete r)),
 StateAdam (X (ADRnnMnistParameters Concrete r)))
res@(Concrete (X (ADRnnMnistParameters Concrete r))
parameters2, StateAdam (X (ADRnnMnistParameters Concrete r))
_) =
                   forall a (x :: TK) (z :: TK).
KnownSTK x =>
(a -> ADVal Concrete x -> ADVal Concrete z)
-> [a] -> Concrete x -> StateAdam x -> (Concrete x, StateAdam x)
sgdAdam @(MnistDataBatchR r)
                               @(X (ADRnnMnistParameters Concrete r))
                               MnistDataBatchR r
-> ADVal Concrete (X (ADRnnMnistParameters Concrete r))
-> ADVal Concrete (TKScalar r)
f [MnistDataBatchR r]
chunkR Concrete (X (ADRnnMnistParameters Concrete r))
parameters StateAdam (X (ADRnnMnistParameters Concrete r))
stateAdam
                 trainScore :: r
trainScore =
                   Int
-> MnistDataBatchR r
-> Concrete (X (ADRnnMnistParameters Concrete r))
-> r
ftest ([MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk) ([MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
chunk) Concrete (X (ADRnnMnistParameters Concrete r))
parameters2
                 testScore :: r
testScore =
                   Int
-> MnistDataBatchR r
-> Concrete (X (ADRnnMnistParameters Concrete r))
-> r
ftest ((Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) Int -> Int -> Int
forall a. Ord a => a -> a -> a
`min` Int
10000)
                         MnistDataBatchR r
testDataR Concrete (X (ADRnnMnistParameters Concrete r))
parameters2
                 lenChunk :: Int
lenChunk = [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
                        String
prefix Int
k Int
lenChunk
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error:   %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
             (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(Concrete (X (ADRnnMnistParameters Concrete r)),
 StateAdam (X (ADRnnMnistParameters Concrete r)))
res
       let runEpoch :: Int
                    -> ( Concrete (X (ADRnnMnistParameters Concrete r))
                       , StateAdam (X (ADRnnMnistParameters Concrete r)) )
                    -> IO (Concrete (X (ADRnnMnistParameters Concrete r)))
           runEpoch Int
n (Concrete (X (ADRnnMnistParameters Concrete r))
params2, StateAdam (X (ADRnnMnistParameters Concrete r))
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (X (ADRnnMnistParameters Concrete r))
params2
           runEpoch Int
n paramsStateAdam :: (Concrete (X (ADRnnMnistParameters Concrete r)),
 StateAdam (X (ADRnnMnistParameters Concrete r)))
paramsStateAdam@(!Concrete (X (ADRnnMnistParameters Concrete r))
_, !StateAdam (X (ADRnnMnistParameters Concrete r))
_) = do
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
             let trainDataShuffled :: [MnistDataR r]
trainDataShuffled = StdGen -> [MnistDataR r] -> [MnistDataR r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataR r]
trainData
                 chunks :: [(Int, [MnistDataR r])]
chunks = Int -> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                          ([(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])])
-> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
Item [Int]
1 ..]
                          ([[MnistDataR r]] -> [(Int, [MnistDataR r])])
-> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataR r]
trainDataShuffled
             res <- ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
             (TKR2 1 (TKScalar r)))
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
             (TKR2 1 (TKScalar r))))
       (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
  StateAdam
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
             (TKR2 1 (TKScalar r)))
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
             (TKR2 1 (TKScalar r))))
       (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
 -> (Int, [MnistDataR r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                  (TKR2 1 (TKScalar r)))
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                  (TKR2 1 (TKScalar r))))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
       StateAdam
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                  (TKR2 1 (TKScalar r)))
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                  (TKR2 1 (TKScalar r))))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r)))
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
    StateAdam
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r)))
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> [(Int, [MnistDataR r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> (Int, [MnistDataR r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(Concrete (X (ADRnnMnistParameters Concrete r)),
 StateAdam (X (ADRnnMnistParameters Concrete r)))
-> (Int, [MnistDataR r])
-> IO
     (Concrete (X (ADRnnMnistParameters Concrete r)),
      StateAdam (X (ADRnnMnistParameters Concrete r)))
runBatch (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(Concrete (X (ADRnnMnistParameters Concrete r)),
 StateAdam (X (ADRnnMnistParameters Concrete r)))
paramsStateAdam [(Int, [MnistDataR r])]
chunks
             runEpoch (succ n) res
       res <- runEpoch 1 (targetInit, initialStateAdam ftk)
       let testErrorFinal =
             r
1 r -> r -> r
forall a. Num a => a -> a -> a
- Int
-> MnistDataBatchR r
-> Concrete (X (ADRnnMnistParameters Concrete r))
-> r
ftest ((Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) Int -> Int -> Int
forall a. Ord a => a -> a -> a
`min` Int
10000)
                       MnistDataBatchR r
testDataR Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (X (ADRnnMnistParameters Concrete r))
res
       testErrorFinal @?~ expected

{-# SPECIALIZE mnistTestCaseRNNRI
  :: String
  -> Int -> Int -> Int -> Int -> Int -> Double
  -> TestTree #-}

tensorADValMnistTestsRNNRI :: TestTree
tensorADValMnistTestsRNNRI :: TestTree
tensorADValMnistTestsRNNRI = String -> [TestTree] -> TestTree
testGroup String
"RNNR Intermediate MNIST tests"
  [ String -> Int -> Int -> Int -> Int -> Int -> Double -> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r) =>
String -> Int -> Int -> Int -> Int -> Int -> r -> TestTree
mnistTestCaseRNNRI String
"RNNRI 1 epoch, 1 batch" Int
1 Int
1 Int
128 Int
150 Int
5000
                       (Double
0.6026 :: Double)
  , String -> Int -> Int -> Int -> Int -> Int -> Float -> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r) =>
String -> Int -> Int -> Int -> Int -> Int -> r -> TestTree
mnistTestCaseRNNRI String
"RNNRI artificial 1 2 3 4 5" Int
2 Int
3 Int
4 Int
5 Int
50
                       (Float
0.8933333 :: Float)
  , String -> Int -> Int -> Int -> Int -> Int -> Double -> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r) =>
String -> Int -> Int -> Int -> Int -> Int -> r -> TestTree
mnistTestCaseRNNRI String
"RNNRI artificial 5 4 3 2 1" Int
5 Int
4 Int
3 Int
2 Int
49
                       (Double
0.8622448979591837 :: Double)
  , String -> Int -> Int -> Int -> Int -> Int -> Float -> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r) =>
String -> Int -> Int -> Int -> Int -> Int -> r -> TestTree
mnistTestCaseRNNRI String
"RNNRI 1 epoch, 0 batch" Int
1 Int
0 Int
128 Int
150 Int
50
                       (Float
1.0 :: Float)
  ]

-- JAX differentiation, Ast term built and differentiated only once
-- and the result interpreted with different inputs in each gradient
-- descent iteration.
mnistTestCaseRNNRO
  :: forall r.
     ( Differentiable r, GoodScalar r
     , PrintfArg r, AssertEqualUpToEpsilon r, ADTensorScalar r ~ r )
  => String
  -> Int -> Int -> Int -> Int -> Int -> r
  -> TestTree
mnistTestCaseRNNRO :: forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r,
 (ADTensorScalar r :: Type) ~ (r :: Type)) =>
String -> Int -> Int -> Int -> Int -> Int -> r -> TestTree
mnistTestCaseRNNRO String
prefix Int
epochs Int
maxBatches Int
width Int
miniBatchSize Int
totalBatchSize
                   r
expected =
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
width ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat @width) ->
  let targetInit :: NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit =
        Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKS2
                  ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKProduct
               (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
 -> NoShape
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct
                     (TKS2
                        ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                     (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKProduct
                     (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                     (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKS2
                  ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct
                    (TKS2
                       ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKProduct
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                    (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a b. (a -> b) -> a -> b
$ (Concrete (X (ADRnnMnistParametersShaped Concrete n r)), StdGen)
-> Concrete (X (ADRnnMnistParametersShaped Concrete n r))
forall a b. (a, b) -> a
fst
        ((Concrete (X (ADRnnMnistParametersShaped Concrete n r)), StdGen)
 -> Concrete (X (ADRnnMnistParametersShaped Concrete n r)))
-> (Concrete (X (ADRnnMnistParametersShaped Concrete n r)), StdGen)
-> Concrete (X (ADRnnMnistParametersShaped Concrete n r))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue @(Concrete (X (ADRnnMnistParametersShaped
                                       Concrete width r)))
                      Double
0.23 (Int -> StdGen
mkStdGen Int
44)
      name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
             String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
                        , Int -> String
forall a. Show a => a -> String
show Int
width, Int -> String
forall a. Show a => a -> String
show Int
miniBatchSize
                        , Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK (X (ADRnnMnistParameters Concrete r)) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
                          (SingletonTK (X (ADRnnMnistParameters Concrete r)) -> Int)
-> SingletonTK (X (ADRnnMnistParameters Concrete r)) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(X (ADRnnMnistParameters Concrete r))
                        , Int -> String
forall a. Show a => a -> String
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r)))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit) ]
      ftest :: Int -> MnistDataBatchR r
            -> Concrete (X (ADRnnMnistParameters Concrete r))
            -> r
      ftest :: Int
-> MnistDataBatchR r
-> Concrete (X (ADRnnMnistParameters Concrete r))
-> r
ftest Int
batch_size MnistDataBatchR r
mnistData Concrete (X (ADRnnMnistParameters Concrete r))
pars =
        Int -> MnistDataBatchR r -> ADRnnMnistParameters Concrete r -> r
forall (target :: TK -> Type) r.
((target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
 GoodScalar r, Differentiable r) =>
Int -> MnistDataBatchR r -> ADRnnMnistParameters target r -> r
MnistRnnRanked2.rnnMnistTestR
          Int
batch_size MnistDataBatchR r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete (X (ADRnnMnistParameters Concrete r))
pars)
  in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
       Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
         String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
                String
prefix Int
epochs Int
maxBatches
       trainData <- (MnistData r -> MnistDataR r) -> [MnistData r] -> [MnistDataR r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataR r
forall r. PrimElt r => MnistData r -> MnistDataR r
mkMnistDataR
                    ([MnistData r] -> [MnistDataR r])
-> IO [MnistData r] -> IO [MnistDataR r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
       testData <- map mkMnistDataR . take (totalBatchSize * maxBatches)
                   <$> loadMnistData testGlyphsPath testLabelsPath
       let testDataR = [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
testData
           dataInit = case Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataR r]
testData of
             [MnistDataR r]
d : [[MnistDataR r]]
_ -> let (Ranked 3 r
dglyph, Ranked 2 r
dlabel) = [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
d
                      in (Ranked 3 r -> Concrete (TKR 3 r)
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 3 r
dglyph, Ranked 2 r -> Concrete (TKR2 2 (TKScalar r))
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 2 r
dlabel)
             [] -> String -> (Concrete (TKR 3 r), Concrete (TKR2 2 (TKScalar r)))
forall a. HasCallStack => String -> a
error String
"empty test data"
           f :: ( ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r
                , ( AstTensor AstMethodLet FullSpan (TKR 3 r)
                  , AstTensor AstMethodLet FullSpan (TKR 2 r) ) )
             -> AstTensor AstMethodLet FullSpan (TKScalar r)
           f = \ (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r
pars, (AstTensor AstMethodLet FullSpan (TKR 3 r)
glyphR, AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
labelR)) ->
             Int
-> (PrimalOf (AstTensor AstMethodLet FullSpan) (TKR 3 r),
    PrimalOf (AstTensor AstMethodLet FullSpan) (TKR2 2 (TKScalar r)))
-> ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (target :: TK -> Type) r.
(ADReady target, ADReady (PrimalOf target), GoodScalar r,
 Differentiable r) =>
Int
-> (PrimalOf target (TKR 3 r), PrimalOf target (TKR 2 r))
-> ADRnnMnistParameters target r
-> target (TKScalar r)
MnistRnnRanked2.rnnMnistLossFusedR
               Int
miniBatchSize (AstTensor AstMethodLet FullSpan (TKR 3 r)
-> PrimalOf (AstTensor AstMethodLet FullSpan) (TKR 3 r)
forall (target :: TK -> Type) (n :: Nat) (x :: TK).
BaseTensor target =>
target (TKR2 n x) -> PrimalOf target (TKR2 n x)
rprimalPart AstTensor AstMethodLet FullSpan (TKR 3 r)
glyphR, AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
-> PrimalOf (AstTensor AstMethodLet FullSpan) (TKR2 2 (TKScalar r))
forall (target :: TK -> Type) (n :: Nat) (x :: TK).
BaseTensor target =>
target (TKR2 n x) -> PrimalOf target (TKR2 n x)
rprimalPart AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
labelR) ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r
pars
           artRaw = ((ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
  (AstTensor AstMethodLet FullSpan (TKR 3 r),
   AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))))
 -> AstTensor AstMethodLet FullSpan (TKScalar r))
-> Value
     (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
      (AstTensor AstMethodLet FullSpan (TKR 3 r),
       AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))))
-> AstArtifactRev
     (X (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
         (AstTensor AstMethodLet FullSpan (TKR 3 r),
          AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r)))))
     (TKScalar r)
forall src r tgt.
((X src :: TK) ~ (X (Value src) :: TK), KnownSTK (X src),
 AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
 AdaptableTarget Concrete (Value src),
 (tgt :: Type)
 ~ (AstTensor AstMethodLet FullSpan (TKScalar r) :: Type)) =>
(src -> tgt) -> Value src -> AstArtifactRev (X src) (TKScalar r)
gradArtifact (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
 (AstTensor AstMethodLet FullSpan (TKR 3 r),
  AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f (Concrete (X (ADRnnMnistParameters Concrete r))
-> ADRnnMnistParameters Concrete r
forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete (X (ADRnnMnistParameters Concrete r))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit, (Concrete (TKR 3 r), Concrete (TKR2 2 (TKScalar r)))
dataInit)
           art = AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r)))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
  (TKScalar r)
-> AstArtifactRev
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
     (TKScalar r)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifactGradient AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r)))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
  (TKScalar r)
AstArtifactRev
  (X (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
      (AstTensor AstMethodLet FullSpan (TKR 3 r),
       AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r)))))
  (TKScalar r)
artRaw
           go :: [MnistDataBatchR r]
              -> ( Concrete (X (ADRnnMnistParameters Concrete r))
                 , StateAdam (X (ADRnnMnistParameters Concrete r)) )
              -> ( Concrete (X (ADRnnMnistParameters Concrete r))
                 , StateAdam (X (ADRnnMnistParameters Concrete r)) )
           go [] (Concrete (X (ADRnnMnistParameters Concrete r))
parameters, StateAdam (X (ADRnnMnistParameters Concrete r))
stateAdam) = (Concrete (X (ADRnnMnistParameters Concrete r))
parameters, StateAdam (X (ADRnnMnistParameters Concrete r))
stateAdam)
           go ((Ranked 3 r
glyph, Ranked 2 r
label) : [MnistDataBatchR r]
rest) (!Concrete (X (ADRnnMnistParameters Concrete r))
parameters, !StateAdam (X (ADRnnMnistParameters Concrete r))
stateAdam) =
             let parametersAndInput :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r)))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
parametersAndInput =
                   Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Concrete (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (X (ADRnnMnistParameters Concrete r))
parameters (Concrete (TKR 3 r)
-> Concrete (TKR2 2 (TKScalar r))
-> Concrete (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (Ranked 3 r -> Concrete (TKR 3 r)
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 3 r
glyph) (Ranked 2 r -> Concrete (TKR2 2 (TKScalar r))
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 2 r
label))
                 gradient :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
gradient = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r)))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r)))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall (x :: TK) (z :: TK). Concrete (TKProduct x z) -> Concrete x
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r)))
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r)))
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r)))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall a b. (a -> b) -> a -> b
$ (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r)))
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))),
 Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
forall a b. (a, b) -> a
fst
                            ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct
                (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                (TKR2 1 (TKScalar r)))
             (TKProduct
                (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                (TKR2 1 (TKScalar r))))
          (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
       (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))),
  Concrete (TKScalar r))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                  (TKR2 1 (TKScalar r)))
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                  (TKR2 1 (TKScalar r))))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                  (TKR2 1 (TKScalar r)))
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                  (TKR2 1 (TKScalar r))))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))),
    Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
forall a b. (a -> b) -> a -> b
$ AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r)))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
  (TKScalar r)
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
-> Maybe (Concrete (ADTensorKind (TKScalar r)))
-> (Concrete
      (ADTensorKind
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct
                     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                     (TKR2 1 (TKScalar r)))
                  (TKProduct
                     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                     (TKR2 1 (TKScalar r))))
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
            (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))),
    Concrete (TKScalar r))
forall (x :: TK) (z :: TK).
AstArtifactRev x z
-> Concrete x
-> Maybe (Concrete (ADTensorKind z))
-> (Concrete (ADTensorKind x), Concrete z)
revInterpretArtifact
                                AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r)))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
  (TKScalar r)
art Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r)))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
              (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
parametersAndInput Maybe (Concrete (ADTensorKind (TKScalar r)))
Maybe (Concrete (TKScalar r))
forall a. Maybe a
Nothing
             in [MnistDataBatchR r]
-> (Concrete (X (ADRnnMnistParameters Concrete r)),
    StateAdam (X (ADRnnMnistParameters Concrete r)))
-> (Concrete (X (ADRnnMnistParameters Concrete r)),
    StateAdam (X (ADRnnMnistParameters Concrete r)))
go [MnistDataBatchR r]
rest (forall (y :: TK).
ArgsAdam
-> StateAdam y
-> SingletonTK y
-> Concrete y
-> Concrete (ADTensorKind y)
-> (Concrete y, StateAdam y)
updateWithGradientAdam
                           @(X (ADRnnMnistParameters Concrete r))
                           ArgsAdam
defaultArgsAdam StateAdam (X (ADRnnMnistParameters Concrete r))
stateAdam SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
SingletonTK (X (ADRnnMnistParameters Concrete r))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete (X (ADRnnMnistParameters Concrete r))
parameters
                           Concrete (ADTensorKind (X (ADRnnMnistParameters Concrete r)))
Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
gradient)
           runBatch :: ( Concrete (X (ADRnnMnistParameters Concrete r))
                       , StateAdam (X (ADRnnMnistParameters Concrete r)) )
                    -> (Int, [MnistDataR r])
                    -> IO ( Concrete (X (ADRnnMnistParameters Concrete r))
                          , StateAdam (X (ADRnnMnistParameters Concrete r)) )
           runBatch (!Concrete (X (ADRnnMnistParameters Concrete r))
parameters, !StateAdam (X (ADRnnMnistParameters Concrete r))
stateAdam) (Int
k, [MnistDataR r]
chunk) = do
             let chunkR :: [MnistDataBatchR r]
chunkR = ([MnistDataR r] -> MnistDataBatchR r)
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR
                          ([[MnistDataR r]] -> [MnistDataBatchR r])
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataR r] -> Bool) -> [[MnistDataR r]] -> [[MnistDataR r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataR r]
ch -> [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSize)
                          ([[MnistDataR r]] -> [[MnistDataR r]])
-> [[MnistDataR r]] -> [[MnistDataR r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataR r]
chunk
                 res :: (Concrete (X (ADRnnMnistParameters Concrete r)),
 StateAdam (X (ADRnnMnistParameters Concrete r)))
res@(Concrete (X (ADRnnMnistParameters Concrete r))
parameters2, StateAdam (X (ADRnnMnistParameters Concrete r))
_) = [MnistDataBatchR r]
-> (Concrete (X (ADRnnMnistParameters Concrete r)),
    StateAdam (X (ADRnnMnistParameters Concrete r)))
-> (Concrete (X (ADRnnMnistParameters Concrete r)),
    StateAdam (X (ADRnnMnistParameters Concrete r)))
go [MnistDataBatchR r]
chunkR (Concrete (X (ADRnnMnistParameters Concrete r))
parameters, StateAdam (X (ADRnnMnistParameters Concrete r))
stateAdam)
                 trainScore :: r
trainScore =
                   Int
-> MnistDataBatchR r
-> Concrete (X (ADRnnMnistParameters Concrete r))
-> r
ftest ([MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk) ([MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
chunk) Concrete (X (ADRnnMnistParameters Concrete r))
parameters2
                 testScore :: r
testScore =
                   Int
-> MnistDataBatchR r
-> Concrete (X (ADRnnMnistParameters Concrete r))
-> r
ftest ((Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) Int -> Int -> Int
forall a. Ord a => a -> a -> a
`min` Int
10000)
                         MnistDataBatchR r
testDataR Concrete (X (ADRnnMnistParameters Concrete r))
parameters2
                 lenChunk :: Int
lenChunk = [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
                        String
prefix Int
k Int
lenChunk
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error:   %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
             (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(Concrete (X (ADRnnMnistParameters Concrete r)),
 StateAdam (X (ADRnnMnistParameters Concrete r)))
res
       let runEpoch :: Int
                    -> ( Concrete (X (ADRnnMnistParameters Concrete r))
                       , StateAdam (X (ADRnnMnistParameters Concrete r)) )
                    -> IO (Concrete (X (ADRnnMnistParameters Concrete r)))
           runEpoch Int
n (Concrete (X (ADRnnMnistParameters Concrete r))
params2, StateAdam (X (ADRnnMnistParameters Concrete r))
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (X (ADRnnMnistParameters Concrete r))
params2
           runEpoch Int
n paramsStateAdam :: (Concrete (X (ADRnnMnistParameters Concrete r)),
 StateAdam (X (ADRnnMnistParameters Concrete r)))
paramsStateAdam@(!Concrete (X (ADRnnMnistParameters Concrete r))
_, !StateAdam (X (ADRnnMnistParameters Concrete r))
_) = do
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
             let trainDataShuffled :: [MnistDataR r]
trainDataShuffled = StdGen -> [MnistDataR r] -> [MnistDataR r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataR r]
trainData
                 chunks :: [(Int, [MnistDataR r])]
chunks = Int -> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                          ([(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])])
-> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
Item [Int]
1 ..]
                          ([[MnistDataR r]] -> [(Int, [MnistDataR r])])
-> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataR r]
trainDataShuffled
             res <- ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
             (TKR2 1 (TKScalar r)))
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
             (TKR2 1 (TKScalar r))))
       (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
  StateAdam
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
             (TKR2 1 (TKScalar r)))
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
             (TKR2 1 (TKScalar r))))
       (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
 -> (Int, [MnistDataR r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                  (TKR2 1 (TKScalar r)))
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                  (TKR2 1 (TKScalar r))))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
       StateAdam
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                  (TKR2 1 (TKScalar r)))
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                  (TKR2 1 (TKScalar r))))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r)))
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
    StateAdam
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r)))
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
               (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> [(Int, [MnistDataR r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> (Int, [MnistDataR r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r)))
              (TKProduct
                 (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
                 (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(Concrete (X (ADRnnMnistParameters Concrete r)),
 StateAdam (X (ADRnnMnistParameters Concrete r)))
-> (Int, [MnistDataR r])
-> IO
     (Concrete (X (ADRnnMnistParameters Concrete r)),
      StateAdam (X (ADRnnMnistParameters Concrete r)))
runBatch (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r)))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
            (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(Concrete (X (ADRnnMnistParameters Concrete r)),
 StateAdam (X (ADRnnMnistParameters Concrete r)))
paramsStateAdam [(Int, [MnistDataR r])]
chunks
             runEpoch (succ n) res
           ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(X (ADRnnMnistParameters
                                                 Concrete r)))
                      Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit
       res <- runEpoch 1 (targetInit, initialStateAdam ftk)
       let testErrorFinal =
             r
1 r -> r -> r
forall a. Num a => a -> a -> a
- Int
-> MnistDataBatchR r
-> Concrete (X (ADRnnMnistParameters Concrete r))
-> r
ftest ((Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) Int -> Int -> Int
forall a. Ord a => a -> a -> a
`min` Int
10000)
                       MnistDataBatchR r
testDataR Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r)))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
           (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (X (ADRnnMnistParameters Concrete r))
res
       assertEqualUpToEpsilon 1e-1 expected testErrorFinal

-- Ask GHC to generate a copy of 'mnistTestCaseRNNRO' specialised to
-- 'Double', avoiding dictionary passing for the Double test cases.
-- NOTE(review): the Float cases in 'tensorADValMnistTestsRNNRO' are not
-- covered by this pragma; confirm whether a Float SPECIALIZE is wanted.
{-# SPECIALIZE mnistTestCaseRNNRO
  :: String
  -> Int -> Int -> Int -> Int -> Int -> Double
  -> TestTree #-}

-- | Test group for the "Once" RNN pipeline ('mnistTestCaseRNNRO').
--
-- Each case passes a test name, five 'Int' configuration arguments and the
-- expected final test error, which is compared up to @1e-1@ by the test case.
-- NOTE(review): the five 'Int's presumably are epochs, max batches, width,
-- mini-batch size and total batch size, in the order of
-- 'mnistTestCaseRNNRO''s parameters; confirm against its definition.
-- Cases alternate between 'Double' and 'Float' to exercise both scalars.
tensorADValMnistTestsRNNRO :: TestTree
tensorADValMnistTestsRNNRO = testGroup "RNNR Once MNIST tests"
  [ mnistTestCaseRNNRO "RNNRO 1 epoch, 1 batch" 1 1 128 150 5000
                       (0.6026 :: Double)
  , mnistTestCaseRNNRO "RNNRO artificial 1 2 3 4 5" 2 3 4 5 50
                       (0.8933333 :: Float)
  , mnistTestCaseRNNRO "RNNRO artificial 5 4 3 2 1" 5 4 3 2 49
                       (0.8928571428571429 :: Double)
    -- With zero batches, no training step runs, and the expected error
    -- is 1.0.
  , mnistTestCaseRNNRO "RNNRO 1 epoch, 0 batch" 1 0 128 150 50
                       (1.0 :: Float)
  ]