{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-- | Tests of "MnistCnnRanked2" convolutional neural network
-- using a few different optimization pipelines.
--
-- With the current CPU backend it's slow enough that it's hard to see
-- if it trains.
module TestMnistCNNR
  ( testTrees
  ) where

import Prelude

import Control.Monad (foldM, unless)
import System.IO (hPutStrLn, stderr)
import System.Random
import Test.Tasty
import Test.Tasty.HUnit hiding (assert)
import Text.Printf

import Data.Array.Nested.Ranked.Shape

import HordeAd
import HordeAd.Core.Adaptor
import HordeAd.Core.AstEnv
import HordeAd.Core.AstFreshId
import HordeAd.Core.AstInterpret

import EqEpsilon

import MnistCnnRanked2 qualified
import MnistData

-- TODO: optimize enough that it can run for one full epoch in reasonable time
-- and then verify it trains down to ~20% validation error in a short enough
-- time to include such a training run in tests.

-- | All test groups exported by this module, in the order they are run.
-- The individual groups are defined further down in this file.
testTrees :: [TestTree]
testTrees = [ tensorADValMnistTestsCNNRA
            , tensorADValMnistTestsCNNRI
            , tensorADValMnistTestsCNNRO
            ]

-- | Tensor-kind representation (obtained via 'X') of the ranked CNN
-- parameters over 'Concrete' tensors. It is used to type the flat parameter
-- tensor (@Concrete (XParams r)@), its Adam optimizer state
-- (@StateAdam (XParams r)@) and its singleton (@knownSTK \@(XParams r)@).
type XParams r =
  X (MnistCnnRanked2.ADCnnMnistParameters Concrete r)

-- POPL differentiation, straight via the ADVal instance of RankedTensor,
-- which side-steps vectorization.
mnistTestCaseCNNRA
  :: forall r.
     (Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r)
  => String
  -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> r
  -> TestTree
mnistTestCaseCNNRA :: forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNRA String
prefix Int
epochs Int
maxBatches Int
khInt Int
kwInt Int
c_outInt Int
n_hiddenInt
                   Int
miniBatchSize Int
totalBatchSize r
expected =
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
khInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_khSNat :: SNat kh) ->
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
kwInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_kwSNat :: SNat kw) ->
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
c_outInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_c_outSNat :: SNat c_out) ->
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
n_hiddenInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_n_hiddenSNat :: SNat n_hidden) ->
  let targetInit :: NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit =
        Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Nat
                 n
                 ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
              (TKScalar r))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Nat
                 n
                 ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
              (TKScalar r))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Nat
                       n
                       ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                    (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Nat
                       n
                       ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                    (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Nat
                  n
                  ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
               (TKScalar r))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Nat
                  n
                  ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
               (TKScalar r))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
            (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
 -> NoShape
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':)
                        @Nat
                        n
                        ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                     (TKScalar r))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':)
                        @Nat
                        n
                        ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                     (TKScalar r))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Nat
                       n
                       ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                    (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Nat
                       n
                       ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                    (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
forall a b. (a -> b) -> a -> b
$ (Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
 StdGen)
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight n n n n r))
forall a b. (a, b) -> a
fst
        ((Concrete
    (X (ADCnnMnistParametersShaped
          Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
  StdGen)
 -> Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight n n n n r)))
-> (Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
    StdGen)
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight n n n n r))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue
            @(Concrete (X (MnistCnnRanked2.ADCnnMnistParametersShaped
                             Concrete SizeMnistHeight SizeMnistWidth
                             kh kw c_out n_hidden r)))
            Double
0.4 (Int -> StdGen
mkStdGen Int
44)
      name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
             String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
                        , Int -> String
forall a. Show a => a -> String
show Int
khInt, Int -> String
forall a. Show a => a -> String
show Int
kwInt
                        , Int -> String
forall a. Show a => a -> String
show Int
c_outInt, Int -> String
forall a. Show a => a -> String
show Int
n_hiddenInt
                        , Int -> String
forall a. Show a => a -> String
show Int
miniBatchSize
                        , Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK (SingletonTK (XParams r) -> Int) -> SingletonTK (XParams r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams r)
                        , Int -> String
forall a. Show a => a -> String
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit) ]
      ftest :: Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
      ftest :: Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest Int
batch_size MnistDataBatchR r
mnistData Concrete (XParams r)
pars =
        Int -> MnistDataBatchR r -> ADCnnMnistParameters Concrete r -> r
forall (target :: TK -> Type) r.
((target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
 GoodScalar r, Differentiable r) =>
Int -> MnistDataBatchR r -> ADCnnMnistParameters Concrete r -> r
MnistCnnRanked2.convMnistTestR
          Int
batch_size MnistDataBatchR r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete (XParams r)
pars)
  in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
       Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
         String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
                String
prefix Int
epochs Int
maxBatches
       trainData <- (MnistData r -> MnistDataR r) -> [MnistData r] -> [MnistDataR r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataR r
forall r. PrimElt r => MnistData r -> MnistDataR r
mkMnistDataR
                    ([MnistData r] -> [MnistDataR r])
-> IO [MnistData r] -> IO [MnistDataR r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
       testData <- map mkMnistDataR . take (totalBatchSize * maxBatches)
                   <$> loadMnistData testGlyphsPath testLabelsPath
       let testDataR = [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
testData
           f :: MnistDataBatchR r -> ADVal Concrete (XParams r)
             -> ADVal Concrete (TKScalar r)
           f (Ranked 3 r
glyphR, Ranked 2 r
labelR) ADVal Concrete (XParams r)
adinputs =
             Int
-> (PrimalOf (ADVal Concrete) (TKR 3 r),
    PrimalOf (ADVal Concrete) (TKR2 2 (TKScalar r)))
-> ADCnnMnistParameters (ADVal Concrete) r
-> ADVal Concrete (TKScalar r)
forall (target :: TK -> Type) r.
(ADReady target, ADReady (PrimalOf target), GoodScalar r,
 Differentiable r) =>
Int
-> (PrimalOf target (TKR 3 r), PrimalOf target (TKR 2 r))
-> ADCnnMnistParameters target r
-> target (TKScalar r)
MnistCnnRanked2.convMnistLossFusedR
               Int
miniBatchSize (Ranked 3 r -> Concrete (TKR 3 r)
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 3 r
glyphR, Ranked 2 r -> Concrete (TKR2 2 (TKScalar r))
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 2 r
labelR)
               (ADVal Concrete (X (ADCnnMnistParameters (ADVal Concrete) r))
-> ADCnnMnistParameters (ADVal Concrete) r
forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget ADVal Concrete (X (ADCnnMnistParameters (ADVal Concrete) r))
ADVal Concrete (XParams r)
adinputs)
           runBatch :: (Concrete (XParams r), StateAdam (XParams r))
                    -> (Int, [MnistDataR r])
                    -> IO (Concrete (XParams r), StateAdam (XParams r))
           runBatch (!Concrete (XParams r)
parameters, !StateAdam (XParams r)
stateAdam) (Int
k, [MnistDataR r]
chunk) = do
             let chunkR :: [MnistDataBatchR r]
chunkR = ([MnistDataR r] -> MnistDataBatchR r)
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR
                          ([[MnistDataR r]] -> [MnistDataBatchR r])
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataR r] -> Bool) -> [[MnistDataR r]] -> [[MnistDataR r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataR r]
ch -> [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSize)
                          ([[MnistDataR r]] -> [[MnistDataR r]])
-> [[MnistDataR r]] -> [[MnistDataR r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataR r]
chunk
                 res :: (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
res@(Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
parameters2, StateAdam
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
_) =
                   (MnistDataBatchR r
 -> ADVal
      Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
 -> ADVal Concrete (TKScalar r))
-> [MnistDataBatchR r]
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> StateAdam
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
    StateAdam
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall a (x :: TK) (z :: TK).
KnownSTK x =>
(a -> ADVal Concrete x -> ADVal Concrete z)
-> [a] -> Concrete x -> StateAdam x -> (Concrete x, StateAdam x)
sgdAdam MnistDataBatchR r
-> ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> ADVal Concrete (TKScalar r)
MnistDataBatchR r
-> ADVal Concrete (XParams r) -> ADVal Concrete (TKScalar r)
f [MnistDataBatchR r]
chunkR Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
parameters StateAdam
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
StateAdam (XParams r)
stateAdam
                 trainScore :: r
trainScore =
                   Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest ([MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk) ([MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
chunk) Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
parameters2
                 testScore :: r
testScore =
                   Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest (Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) MnistDataBatchR r
testDataR Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
parameters2
                 lenChunk :: Int
lenChunk = [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
                        String
prefix Int
k Int
lenChunk
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error:   %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
             (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
res
       let runEpoch :: Int
                    -> (Concrete (XParams r), StateAdam (XParams r))
                    -> IO (Concrete (XParams r))
           runEpoch Int
n (Concrete (XParams r)
params2, StateAdam (XParams r)
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
params2
           runEpoch Int
n paramsStateAdam :: (Concrete (XParams r), StateAdam (XParams r))
paramsStateAdam@(!Concrete (XParams r)
_, !StateAdam (XParams r)
_) = do
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
             let trainDataShuffled :: [MnistDataR r]
trainDataShuffled = StdGen -> [MnistDataR r] -> [MnistDataR r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataR r]
trainData
                 chunks :: [(Int, [MnistDataR r])]
chunks = Int -> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                          ([(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])])
-> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..]
                          ([[MnistDataR r]] -> [(Int, [MnistDataR r])])
-> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataR r]
trainDataShuffled
             res <- ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
          (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
       (TKProduct
          (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
          (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
  StateAdam
    (TKProduct
       (TKProduct
          (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
          (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
       (TKProduct
          (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
          (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
 -> (Int, [MnistDataR r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
       StateAdam
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
    StateAdam
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> [(Int, [MnistDataR r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> (Int, [MnistDataR r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
(Concrete (XParams r), StateAdam (XParams r))
-> (Int, [MnistDataR r])
-> IO (Concrete (XParams r), StateAdam (XParams r))
runBatch (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
(Concrete (XParams r), StateAdam (XParams r))
paramsStateAdam [(Int, [MnistDataR r])]
chunks
             runEpoch (succ n) res
           ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams r)) Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit
       res <- runEpoch 1 (targetInit, initialStateAdam ftk)
       let testErrorFinal =
             r
1 r -> r -> r
forall a. Num a => a -> a -> a
- Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest (Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) MnistDataBatchR r
testDataR Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
res
       testErrorFinal @?~ expected

{-# SPECIALIZE mnistTestCaseCNNRA
  :: String
  -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Double
  -> TestTree #-}

-- | CNN tests for the ADVal pipeline: every case is built with
-- 'mnistTestCaseCNNRA', which differentiates straight through the ADVal
-- instance (side-stepping vectorization).
--
-- Arguments of each case, in order: test name, epochs, max batches per
-- epoch, kernel height, kernel width, c_out, n_hidden, mini-batch size,
-- total batch size, and the expected final test error (compared with
-- @\@?~@ at the end of the test).
tensorADValMnistTestsCNNRA :: TestTree
tensorADValMnistTestsCNNRA = testGroup "CNNR ADVal MNIST tests" caseList
  where
    caseList =
      [ mnistTestCaseCNNRA "CNNRA 1 epoch, 1 batch"
          1 1   4 4   8 16   1 1    (1 :: Double)
      , mnistTestCaseCNNRA "CNNRA artificial 1 2 3 4 5"
          1 1   2 3   4 5    1 10   (1 :: Float)
      , mnistTestCaseCNNRA "CNNRA artificial 5 4 3 2 1"
          5 4   3 2   1 1    1 1    (1 :: Double)
      -- zero batches: nothing is trained and no test data is taken
      , mnistTestCaseCNNRA "CNNRA 1 epoch, 0 batch"
          1 0   4 4   16 64  16 50  (1.0 :: Float)
      ]

-- POPL differentiation, with Ast term defined and vectorized only once,
-- but differentiated anew in each gradient descent iteration.
mnistTestCaseCNNRI
  :: forall r.
     (Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r)
  => String
  -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> r
  -> TestTree
mnistTestCaseCNNRI :: forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNRI String
prefix Int
epochs Int
maxBatches Int
khInt Int
kwInt Int
c_outInt Int
n_hiddenInt
                   Int
miniBatchSize Int
totalBatchSize r
expected =
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
khInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_khSNat :: SNat kh) ->
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
kwInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_kwSNat :: SNat kw) ->
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
c_outInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_c_outSNat :: SNat c_out) ->
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
n_hiddenInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_n_hiddenSNat :: SNat n_hidden) ->
  let targetInit :: NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit =
        Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Nat
                 n
                 ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
              (TKScalar r))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Nat
                 n
                 ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
              (TKScalar r))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Nat
                       n
                       ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                    (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Nat
                       n
                       ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                    (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Nat
                  n
                  ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
               (TKScalar r))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Nat
                  n
                  ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
               (TKScalar r))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
            (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
 -> NoShape
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':)
                        @Nat
                        n
                        ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                     (TKScalar r))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':)
                        @Nat
                        n
                        ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                     (TKScalar r))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Nat
                       n
                       ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                    (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Nat
                       n
                       ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                    (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
forall a b. (a -> b) -> a -> b
$ (Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
 StdGen)
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight n n n n r))
forall a b. (a, b) -> a
fst
        ((Concrete
    (X (ADCnnMnistParametersShaped
          Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
  StdGen)
 -> Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight n n n n r)))
-> (Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
    StdGen)
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight n n n n r))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue
            @(Concrete (X (MnistCnnRanked2.ADCnnMnistParametersShaped
                             Concrete SizeMnistHeight SizeMnistWidth
                             kh kw c_out n_hidden r)))
            Double
0.4 (Int -> StdGen
mkStdGen Int
44)
      name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
             String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
                        , Int -> String
forall a. Show a => a -> String
show Int
khInt, Int -> String
forall a. Show a => a -> String
show Int
kwInt
                        , Int -> String
forall a. Show a => a -> String
show Int
c_outInt, Int -> String
forall a. Show a => a -> String
show Int
n_hiddenInt
                        , Int -> String
forall a. Show a => a -> String
show Int
miniBatchSize
                        , Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK (SingletonTK (XParams r) -> Int) -> SingletonTK (XParams r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams r)
                        , Int -> String
forall a. Show a => a -> String
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit) ]
      ftest :: Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
      ftest :: Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest Int
batch_size MnistDataBatchR r
mnistData Concrete (XParams r)
pars =
        Int -> MnistDataBatchR r -> ADCnnMnistParameters Concrete r -> r
forall (target :: TK -> Type) r.
((target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
 GoodScalar r, Differentiable r) =>
Int -> MnistDataBatchR r -> ADCnnMnistParameters Concrete r -> r
MnistCnnRanked2.convMnistTestR
          Int
batch_size MnistDataBatchR r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete (XParams r)
pars)
  in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
       Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
         String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
                String
prefix Int
epochs Int
maxBatches
       trainData <- (MnistData r -> MnistDataR r) -> [MnistData r] -> [MnistDataR r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataR r
forall r. PrimElt r => MnistData r -> MnistDataR r
mkMnistDataR
                    ([MnistData r] -> [MnistDataR r])
-> IO [MnistData r] -> IO [MnistDataR r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
       testData <- map mkMnistDataR . take (totalBatchSize * maxBatches)
                   <$> loadMnistData testGlyphsPath testLabelsPath
       let testDataR = [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
testData
           ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams r)) Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit
       (_, _, var, varAst2) <- funToAstRevIO ftk
       (varGlyph, astGlyph) <-
         funToAstIO (FTKR (miniBatchSize
                           :$: sizeMnistHeightInt
                           :$: sizeMnistWidthInt
                           :$: ZSR) FTKScalar) id
       (varLabel, astLabel) <-
         funToAstIO (FTKR (miniBatchSize
                           :$: sizeMnistLabelInt
                           :$: ZSR) FTKScalar) id
       let ast :: AstTensor AstMethodLet FullSpan (TKScalar r)
           ast = AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (z :: TK) (s :: AstSpanType).
AstSpan s =>
AstTensor AstMethodLet s z -> AstTensor AstMethodLet s z
simplifyInline
                 (AstTensor AstMethodLet FullSpan (TKScalar r)
 -> AstTensor AstMethodLet FullSpan (TKScalar r))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall a b. (a -> b) -> a -> b
$ Int
-> (PrimalOf
      (AstTensor AstMethodLet FullSpan) (TKR2 3 (TKScalar r)),
    PrimalOf (AstTensor AstMethodLet FullSpan) (TKR2 2 (TKScalar r)))
-> ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (target :: TK -> Type) r.
(ADReady target, ADReady (PrimalOf target), GoodScalar r,
 Differentiable r) =>
Int
-> (PrimalOf target (TKR 3 r), PrimalOf target (TKR 2 r))
-> ADCnnMnistParameters target r
-> target (TKScalar r)
MnistCnnRanked2.convMnistLossFusedR
                     Int
miniBatchSize (AstTensor AstMethodLet PrimalSpan (TKR2 3 (TKScalar r))
PrimalOf (AstTensor AstMethodLet FullSpan) (TKR2 3 (TKScalar r))
astGlyph, AstTensor AstMethodLet PrimalSpan (TKR2 2 (TKScalar r))
PrimalOf (AstTensor AstMethodLet FullSpan) (TKR2 2 (TKScalar r))
astLabel)
                     (AstTensor
  AstMethodLet
  FullSpan
  (X (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r))
-> ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r
forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget AstTensor
  AstMethodLet
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
AstTensor
  AstMethodLet
  FullSpan
  (X (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r))
varAst2)
           f :: MnistDataBatchR r -> ADVal Concrete (XParams r)
             -> ADVal Concrete (TKScalar r)
           f (Ranked 3 r
glyph, Ranked 2 r
label) ADVal Concrete (XParams r)
varInputs =
             let env :: AstEnv (ADVal Concrete)
env = AstVarName
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
var ADVal
  Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
ADVal Concrete (XParams r)
varInputs AstEnv (ADVal Concrete)
forall (target :: TK -> Type). AstEnv target
emptyEnv
                 envMnist :: AstEnv (ADVal Concrete)
envMnist = AstVarName PrimalSpan (TKR2 3 (TKScalar r))
-> ADVal Concrete (TKR2 3 (TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName PrimalSpan (TKR2 3 (TKScalar r))
varGlyph (Ranked 3 r -> ADVal Concrete (TKR2 3 (TKScalar r))
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 3 r
glyph)
                            (AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete))
-> AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete)
forall a b. (a -> b) -> a -> b
$ AstVarName PrimalSpan (TKR2 2 (TKScalar r))
-> ADVal Concrete (TKR2 2 (TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName PrimalSpan (TKR2 2 (TKScalar r))
varLabel (Ranked 2 r -> ADVal Concrete (TKR2 2 (TKScalar r))
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 2 r
label) AstEnv (ADVal Concrete)
env
             in AstEnv (ADVal Concrete)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> ADVal Concrete (TKScalar r)
forall (target :: TK -> Type) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull AstEnv (ADVal Concrete)
envMnist AstTensor AstMethodLet FullSpan (TKScalar r)
ast
           runBatch :: (Concrete (XParams r), StateAdam (XParams r))
                    -> (Int, [MnistDataR r])
                    -> IO (Concrete (XParams r), StateAdam (XParams r))
           runBatch (!Concrete (XParams r)
parameters, !StateAdam (XParams r)
stateAdam) (Int
k, [MnistDataR r]
chunk) = do
             let chunkR :: [MnistDataBatchR r]
chunkR = ([MnistDataR r] -> MnistDataBatchR r)
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR
                          ([[MnistDataR r]] -> [MnistDataBatchR r])
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataR r] -> Bool) -> [[MnistDataR r]] -> [[MnistDataR r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataR r]
ch -> [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSize)
                          ([[MnistDataR r]] -> [[MnistDataR r]])
-> [[MnistDataR r]] -> [[MnistDataR r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataR r]
chunk
                 res :: (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
res@(Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
parameters2, StateAdam
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
_) =
                   (MnistDataBatchR r
 -> ADVal
      Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
 -> ADVal Concrete (TKScalar r))
-> [MnistDataBatchR r]
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> StateAdam
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
    StateAdam
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall a (x :: TK) (z :: TK).
KnownSTK x =>
(a -> ADVal Concrete x -> ADVal Concrete z)
-> [a] -> Concrete x -> StateAdam x -> (Concrete x, StateAdam x)
sgdAdam MnistDataBatchR r
-> ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> ADVal Concrete (TKScalar r)
MnistDataBatchR r
-> ADVal Concrete (XParams r) -> ADVal Concrete (TKScalar r)
f [MnistDataBatchR r]
chunkR Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
parameters StateAdam
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
StateAdam (XParams r)
stateAdam
                 !trainScore :: r
trainScore =
                   Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest ([MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk) ([MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
chunk) Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
parameters2
                 !testScore :: r
testScore =
                   Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest (Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) MnistDataBatchR r
testDataR Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
parameters2
                 !lenChunk :: Int
lenChunk = [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
                        String
prefix Int
k Int
lenChunk
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error:   %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
             (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
res
       let runEpoch :: Int
                    -> (Concrete (XParams r), StateAdam (XParams r))
                    -> IO (Concrete (XParams r))
           runEpoch Int
n (Concrete (XParams r)
params2, StateAdam (XParams r)
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
params2
           runEpoch Int
n paramsStateAdam :: (Concrete (XParams r), StateAdam (XParams r))
paramsStateAdam@(!Concrete (XParams r)
_, !StateAdam (XParams r)
_) = do
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
             let trainDataShuffled :: [MnistDataR r]
trainDataShuffled = StdGen -> [MnistDataR r] -> [MnistDataR r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataR r]
trainData
                 chunks :: [(Int, [MnistDataR r])]
chunks = Int -> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                          ([(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])])
-> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..]
                          ([[MnistDataR r]] -> [(Int, [MnistDataR r])])
-> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataR r]
trainDataShuffled
             res <- ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
          (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
       (TKProduct
          (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
          (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
  StateAdam
    (TKProduct
       (TKProduct
          (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
          (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
       (TKProduct
          (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
          (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
 -> (Int, [MnistDataR r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
       StateAdam
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
    StateAdam
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> [(Int, [MnistDataR r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> (Int, [MnistDataR r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
(Concrete (XParams r), StateAdam (XParams r))
-> (Int, [MnistDataR r])
-> IO (Concrete (XParams r), StateAdam (XParams r))
runBatch (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
(Concrete (XParams r), StateAdam (XParams r))
paramsStateAdam [(Int, [MnistDataR r])]
chunks
             runEpoch (succ n) res
       res <- runEpoch 1 (targetInit, initialStateAdam ftk)
       let testErrorFinal =
             r
1 r -> r -> r
forall a. Num a => a -> a -> a
- Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest (Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) MnistDataBatchR r
testDataR Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
res
       testErrorFinal @?~ expected

{-# SPECIALIZE mnistTestCaseCNNRI
  :: String
  -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Double
  -> TestTree #-}

-- | MNIST tests of the "MnistCnnRanked2" network via the intermediate
-- pipeline ('mnistTestCaseCNNRI'), run on a few tiny configurations.
--
-- Positional arguments of each case, assuming 'mnistTestCaseCNNRI' takes
-- them in the same order as 'mnistTestCaseCNNRA' and 'mnistTestCaseCNNRO'
-- (TODO confirm): test name, epochs, maxBatches, kh, kw, c_out, n_hidden,
-- miniBatchSize, totalBatchSize, and the expected final validation error
-- (compared with '@?~').
tensorADValMnistTestsCNNRI :: TestTree
tensorADValMnistTestsCNNRI = testGroup "CNNR Intermediate MNIST tests"
  [ mnistTestCaseCNNRI "CNNRI 1 epoch, 1 batch" 1 1 4 4 8 16 1 1
                       (1 :: Double)
  , mnistTestCaseCNNRI "CNNRI artificial 1 2 3 4 5" 1 1 2 3 4 5 1 10
                       (1 :: Float)
  , mnistTestCaseCNNRI "CNNRI artificial 5 4 3 2 1" 5 4 3 2 1 1 1 1
                       (1 :: Double)
    -- Second Int argument is 0 (presumably maxBatches), so no training
    -- batch should be taken in the epoch loop.
  , mnistTestCaseCNNRI "CNNRI 1 epoch, 0 batch" 1 0 4 4 16 64 16 50
                       (1.0 :: Float)
  ]

-- JAX differentiation, Ast term built and differentiated only once
-- and the result interpreted with different inputs in each gradient
-- descent iteration.
mnistTestCaseCNNRO
  :: forall r.
     ( Differentiable r, GoodScalar r
     , PrintfArg r, AssertEqualUpToEpsilon r, ADTensorScalar r ~ r )
  => String
  -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> r
  -> TestTree
mnistTestCaseCNNRO :: forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r,
 (ADTensorScalar r :: Type) ~ (r :: Type)) =>
String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNRO String
prefix Int
epochs Int
maxBatches Int
khInt Int
kwInt Int
c_outInt Int
n_hiddenInt
                   Int
miniBatchSize Int
totalBatchSize r
expected =
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
khInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_khSNat :: SNat kh) ->
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
kwInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_kwSNat :: SNat kw) ->
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
c_outInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_c_outSNat :: SNat c_out) ->
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
n_hiddenInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_n_hiddenSNat :: SNat n_hidden) ->
  let targetInit :: NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit =
        Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2
              ((':)
                 @Nat
                 n
                 ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
              (TKScalar r))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':)
                 @Nat
                 n
                 ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
              (TKScalar r))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKProduct
           (TKS2
              ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Nat
                       n
                       ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                    (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Nat
                       n
                       ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                    (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2
               ((':)
                  @Nat
                  n
                  ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
               (TKScalar r))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':)
                  @Nat
                  n
                  ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
               (TKScalar r))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKProduct
            (TKS2
               ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKS2
               ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
            (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
 -> NoShape
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':)
                        @Nat
                        n
                        ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                     (TKScalar r))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':)
                        @Nat
                        n
                        ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                     (TKScalar r))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKProduct
                  (TKS2
                     ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKS2
                     ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':)
                       @Nat
                       n
                       ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                    (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':)
                       @Nat
                       n
                       ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                    (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKProduct
                 (TKS2
                    ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2
                    ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
forall a b. (a -> b) -> a -> b
$ (Concrete
   (X (ADCnnMnistParametersShaped
         Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
 StdGen)
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight n n n n r))
forall a b. (a, b) -> a
fst
        ((Concrete
    (X (ADCnnMnistParametersShaped
          Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
  StdGen)
 -> Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight n n n n r)))
-> (Concrete
      (X (ADCnnMnistParametersShaped
            Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
    StdGen)
-> Concrete
     (X (ADCnnMnistParametersShaped
           Concrete SizeMnistHeight SizeMnistHeight n n n n r))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue
            @(Concrete (X (MnistCnnRanked2.ADCnnMnistParametersShaped
                             Concrete SizeMnistHeight SizeMnistWidth
                             kh kw c_out n_hidden r)))
            Double
0.4 (Int -> StdGen
mkStdGen Int
44)
      name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
             String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
                        , Int -> String
forall a. Show a => a -> String
show Int
khInt, Int -> String
forall a. Show a => a -> String
show Int
kwInt
                        , Int -> String
forall a. Show a => a -> String
show Int
c_outInt, Int -> String
forall a. Show a => a -> String
show Int
n_hiddenInt
                        , Int -> String
forall a. Show a => a -> String
show Int
miniBatchSize
                        , Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK (SingletonTK (XParams r) -> Int) -> SingletonTK (XParams r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams r)
                        , Int -> String
forall a. Show a => a -> String
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit) ]
      ftest :: Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
      ftest :: Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest Int
batch_size MnistDataBatchR r
mnistData Concrete (XParams r)
pars =
        Int -> MnistDataBatchR r -> ADCnnMnistParameters Concrete r -> r
forall (target :: TK -> Type) r.
((target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
 GoodScalar r, Differentiable r) =>
Int -> MnistDataBatchR r -> ADCnnMnistParameters Concrete r -> r
MnistCnnRanked2.convMnistTestR
          Int
batch_size MnistDataBatchR r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete (XParams r)
pars)
  in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
       Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
         String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
                String
prefix Int
epochs Int
maxBatches
       trainData <- (MnistData r -> MnistDataR r) -> [MnistData r] -> [MnistDataR r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataR r
forall r. PrimElt r => MnistData r -> MnistDataR r
mkMnistDataR
                    ([MnistData r] -> [MnistDataR r])
-> IO [MnistData r] -> IO [MnistDataR r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
       testData <- map mkMnistDataR . take (totalBatchSize * maxBatches)
                   <$> loadMnistData testGlyphsPath testLabelsPath
       let testDataR = [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
testData
           dataInit = case Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataR r]
testData of
             [MnistDataR r]
d : [[MnistDataR r]]
_ -> let (Ranked 3 r
dglyph, Ranked 2 r
dlabel) = [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
d
                      in (Ranked 3 r -> Concrete (TKR 3 r)
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 3 r
dglyph, Ranked 2 r -> Concrete (TKR2 2 (TKScalar r))
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 2 r
dlabel)
             [] -> String -> (Concrete (TKR 3 r), Concrete (TKR2 2 (TKScalar r)))
forall a. HasCallStack => String -> a
error String
"empty test data"
           f :: ( MnistCnnRanked2.ADCnnMnistParameters
                    (AstTensor AstMethodLet FullSpan) r
                , ( AstTensor AstMethodLet FullSpan (TKR 3 r)
                  , AstTensor AstMethodLet FullSpan (TKR 2 r) ) )
             -> AstTensor AstMethodLet FullSpan (TKScalar r)
           f = \ (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r
pars, (AstTensor AstMethodLet FullSpan (TKR 3 r)
glyphR, AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
labelR)) ->
             Int
-> (PrimalOf (AstTensor AstMethodLet FullSpan) (TKR 3 r),
    PrimalOf (AstTensor AstMethodLet FullSpan) (TKR2 2 (TKScalar r)))
-> ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (target :: TK -> Type) r.
(ADReady target, ADReady (PrimalOf target), GoodScalar r,
 Differentiable r) =>
Int
-> (PrimalOf target (TKR 3 r), PrimalOf target (TKR 2 r))
-> ADCnnMnistParameters target r
-> target (TKScalar r)
MnistCnnRanked2.convMnistLossFusedR
               Int
miniBatchSize (AstTensor AstMethodLet FullSpan (TKR 3 r)
-> PrimalOf (AstTensor AstMethodLet FullSpan) (TKR 3 r)
forall (target :: TK -> Type) (n :: Nat) (x :: TK).
BaseTensor target =>
target (TKR2 n x) -> PrimalOf target (TKR2 n x)
rprimalPart AstTensor AstMethodLet FullSpan (TKR 3 r)
glyphR, AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
-> PrimalOf (AstTensor AstMethodLet FullSpan) (TKR2 2 (TKScalar r))
forall (target :: TK -> Type) (n :: Nat) (x :: TK).
BaseTensor target =>
target (TKR2 n x) -> PrimalOf target (TKR2 n x)
rprimalPart AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
labelR) ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r
pars
           artRaw = ((ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
  (AstTensor AstMethodLet FullSpan (TKR 3 r),
   AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))))
 -> AstTensor AstMethodLet FullSpan (TKScalar r))
-> Value
     (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
      (AstTensor AstMethodLet FullSpan (TKR 3 r),
       AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))))
-> AstArtifactRev
     (X (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
         (AstTensor AstMethodLet FullSpan (TKR 3 r),
          AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r)))))
     (TKScalar r)
forall src r tgt.
((X src :: TK) ~ (X (Value src) :: TK), KnownSTK (X src),
 AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
 AdaptableTarget Concrete (Value src),
 (tgt :: Type)
 ~ (AstTensor AstMethodLet FullSpan (TKScalar r) :: Type)) =>
(src -> tgt) -> Value src -> AstArtifactRev (X src) (TKScalar r)
gradArtifact (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
 (AstTensor AstMethodLet FullSpan (TKR 3 r),
  AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f (Concrete (XParams r) -> ADCnnMnistParameters Concrete r
forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete (XParams r)
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit, (Concrete (TKR 3 r), Concrete (TKR2 2 (TKScalar r)))
dataInit)
           art = AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
     (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
  (TKScalar r)
-> AstArtifactRev
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
        (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
     (TKScalar r)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifactGradient AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
     (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
  (TKScalar r)
AstArtifactRev
  (X (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
      (AstTensor AstMethodLet FullSpan (TKR 3 r),
       AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r)))))
  (TKScalar r)
artRaw
           go :: [MnistDataBatchR r]
              -> (Concrete (XParams r), StateAdam (XParams r))
              -> (Concrete (XParams r), StateAdam (XParams r))
           go [] (Concrete (XParams r)
parameters, StateAdam (XParams r)
stateAdam) = (Concrete (XParams r)
parameters, StateAdam (XParams r)
stateAdam)
           go ((Ranked 3 r
glyph, Ranked 2 r
label) : [MnistDataBatchR r]
rest) (!Concrete (XParams r)
parameters, !StateAdam (XParams r)
stateAdam) =
             let parametersAndInput :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
     (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
parametersAndInput =
                   Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Concrete (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
        (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
parameters (Concrete (TKR 3 r)
-> Concrete (TKR2 2 (TKScalar r))
-> Concrete (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (Ranked 3 r -> Concrete (TKR 3 r)
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 3 r
glyph) (Ranked 2 r -> Concrete (TKR2 2 (TKScalar r))
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 2 r
label))
                 gradient :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
gradient =
                   Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
     (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall (x :: TK) (z :: TK). Concrete (TKProduct x z) -> Concrete x
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
      (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
        (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall a b. (a -> b) -> a -> b
$ (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
      (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))),
 Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
        (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
forall a b. (a, b) -> a
fst
                   ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
             (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
             (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
       (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))),
  Concrete (TKScalar r))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
         (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
         (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))),
    Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
        (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
forall a b. (a -> b) -> a -> b
$ AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
     (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
  (TKScalar r)
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
        (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
-> Maybe (Concrete (ADTensorKind (TKScalar r)))
-> (Concrete
      (ADTensorKind
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
                  (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
            (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))),
    Concrete (TKScalar r))
forall (x :: TK) (z :: TK).
AstArtifactRev x z
-> Concrete x
-> Maybe (Concrete (ADTensorKind z))
-> (Concrete (ADTensorKind x), Concrete z)
revInterpretArtifact AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
     (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
  (TKScalar r)
art Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
     (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
parametersAndInput Maybe (Concrete (ADTensorKind (TKScalar r)))
Maybe (Concrete (TKScalar r))
forall a. Maybe a
Nothing
             in [MnistDataBatchR r]
-> (Concrete (XParams r), StateAdam (XParams r))
-> (Concrete (XParams r), StateAdam (XParams r))
go [MnistDataBatchR r]
rest (forall (y :: TK).
ArgsAdam
-> StateAdam y
-> SingletonTK y
-> Concrete y
-> Concrete (ADTensorKind y)
-> (Concrete y, StateAdam y)
updateWithGradientAdam
                           @(XParams r)
                           ArgsAdam
defaultArgsAdam StateAdam (XParams r)
stateAdam SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
SingletonTK (XParams r)
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete (XParams r)
parameters
                           Concrete (ADTensorKind (XParams r))
Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
gradient)
           runBatch :: (Concrete (XParams r), StateAdam (XParams r))
                    -> (Int, [MnistDataR r])
                    -> IO (Concrete (XParams r), StateAdam (XParams r))
           runBatch (!Concrete (XParams r)
parameters, !StateAdam (XParams r)
stateAdam) (Int
k, [MnistDataR r]
chunk) = do
             let chunkR :: [MnistDataBatchR r]
chunkR = ([MnistDataR r] -> MnistDataBatchR r)
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR
                          ([[MnistDataR r]] -> [MnistDataBatchR r])
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataR r] -> Bool) -> [[MnistDataR r]] -> [[MnistDataR r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataR r]
ch -> [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSize)
                          ([[MnistDataR r]] -> [[MnistDataR r]])
-> [[MnistDataR r]] -> [[MnistDataR r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataR r]
chunk
                 res :: (Concrete (XParams r), StateAdam (XParams r))
res@(Concrete (XParams r)
parameters2, StateAdam (XParams r)
_) = [MnistDataBatchR r]
-> (Concrete (XParams r), StateAdam (XParams r))
-> (Concrete (XParams r), StateAdam (XParams r))
go [MnistDataBatchR r]
chunkR (Concrete (XParams r)
parameters, StateAdam (XParams r)
stateAdam)
                 trainScore :: r
trainScore =
                   Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest ([MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk) ([MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
chunk) Concrete (XParams r)
parameters2
                 testScore :: r
testScore =
                   Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest (Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) MnistDataBatchR r
testDataR Concrete (XParams r)
parameters2
                 lenChunk :: Int
lenChunk = [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
                        String
prefix Int
k Int
lenChunk
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error:   %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
                 String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
                        String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
             (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
(Concrete (XParams r), StateAdam (XParams r))
res
       let runEpoch :: Int
                    -> (Concrete (XParams r), StateAdam (XParams r))
                    -> IO (Concrete (XParams r))
           runEpoch Int
n (Concrete (XParams r)
params2, StateAdam (XParams r)
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
params2
           runEpoch Int
n paramsStateAdam :: (Concrete (XParams r), StateAdam (XParams r))
paramsStateAdam@(!Concrete (XParams r)
_, !StateAdam (XParams r)
_) = do
             Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
               Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
             let trainDataShuffled :: [MnistDataR r]
trainDataShuffled = StdGen -> [MnistDataR r] -> [MnistDataR r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataR r]
trainData
                 chunks :: [(Int, [MnistDataR r])]
chunks = Int -> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                          ([(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])])
-> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..]
                          ([[MnistDataR r]] -> [(Int, [MnistDataR r])])
-> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataR r]
trainDataShuffled
             res <- ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
          (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
       (TKProduct
          (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
          (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
  StateAdam
    (TKProduct
       (TKProduct
          (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
          (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
       (TKProduct
          (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
          (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
 -> (Int, [MnistDataR r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
       StateAdam
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
    StateAdam
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> [(Int, [MnistDataR r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> (Int, [MnistDataR r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
      StateAdam
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
(Concrete (XParams r), StateAdam (XParams r))
-> (Int, [MnistDataR r])
-> IO (Concrete (XParams r), StateAdam (XParams r))
runBatch (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
 StateAdam
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
(Concrete (XParams r), StateAdam (XParams r))
paramsStateAdam [(Int, [MnistDataR r])]
chunks
             runEpoch (succ n) res
           ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams r)) Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':)
                    @Nat
                    n
                    ((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
                 (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKProduct
              (TKS2
                 ((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit
       res <- runEpoch 1 (targetInit, initialStateAdam ftk)
       let testErrorFinal =
             r
1 r -> r -> r
forall a. Num a => a -> a -> a
- Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest (Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) MnistDataBatchR r
testDataR Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
res
       assertEqualUpToEpsilon 1e-1 expected testErrorFinal

-- Request a 'Double'-specialised copy of the polymorphic CNNRO test driver.
-- No 'Float' specialisation is requested here, even though some of the
-- CNNRO test cases run at 'Float'.
{-# SPECIALIZE mnistTestCaseCNNRO
      :: String -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int
      -> Double -> TestTree #-}

-- | MNIST CNN test cases run through 'mnistTestCaseCNNRO', mixing 'Double'
-- and 'Float' scalars.
--
-- NOTE(review): the eight integer arguments presumably follow the same order
-- as 'mnistTestCaseCNNRA' (epochs, maxBatches, khInt, kwInt, c_outInt,
-- n_hiddenInt, miniBatchSize, totalBatchSize); confirm against the
-- 'mnistTestCaseCNNRO' signature. The trailing literal appears to be the
-- expected final test error, which the driver compares with a tolerance of
-- 1e-1 (1 would mean a 100% error rate). TODO confirm.
tensorADValMnistTestsCNNRO :: TestTree
tensorADValMnistTestsCNNRO = testGroup "CNNR Once MNIST tests"
  [ mnistTestCaseCNNRO "CNNRO 1 epoch, 1 batch" 1 1 4 4 8 16 1 1
                       (1 :: Double)
    -- Small, unusual kernel/channel sizes; presumably a shape sanity check.
  , mnistTestCaseCNNRO "CNNRO artificial 1 2 3 4 5" 1 1 2 3 4 5 1 10
                       (1 :: Float)
  , mnistTestCaseCNNRO "CNNRO artificial 5 4 3 2 1" 5 4 3 2 1 1 1 1
                       (1 :: Double)
    -- maxBatches = 0, so presumably no training batch is processed.
  , mnistTestCaseCNNRO "CNNRO 1 epoch, 0 batch" 1 0 4 4 16 64 16 50
                       (1.0 :: Float)
  ]