{-# LANGUAGE OverloadedLists #-}
-- | Tests of "MnistFcnnRanked1" and "MnistFcnnRanked2" dense neural networks
-- using a few different optimization pipelines.
module TestMnistFCNNR
  ( testTrees
  ) where

import Prelude

import Control.Arrow ((***))
import Control.Monad (foldM, unless)
import Data.Bifunctor (first)
import Data.Proxy (Proxy (Proxy))
import System.IO (hPutStrLn, stderr)
import System.Random
import Test.Tasty
import Test.Tasty.HUnit hiding (assert)
import Test.Tasty.QuickCheck hiding (label, shuffle)
import Text.Printf

import Data.Array.Nested.Ranked.Shape

import HordeAd
import HordeAd.Core.Adaptor
import HordeAd.Core.AstEnv
import HordeAd.Core.AstFreshId
import HordeAd.Core.AstInterpret
import HordeAd.Core.Ops (tconcrete)

import CrossTesting
import EqEpsilon

import MnistData
import MnistFcnnRanked1 qualified
import MnistFcnnRanked2 (XParams2)
import MnistFcnnRanked2 qualified

-- | All tests of this module: a single second-order 'rev'' check on a
-- reverse-mode gradient artifact of "MnistFcnnRanked2", followed by the
-- ADVal, intermediate-pipeline and AD-once test groups (the @2@-suffixed
-- groups presumably target "MnistFcnnRanked2" — see their definitions).
testTrees :: [TestTree]
testTrees = [ testCase "2VTOrev" mnistTestCase2VTOrev
            , tensorADValMnistTests
            , tensorIntermediateMnistTests
            , tensorADOnceMnistTests
            , tensorADValMnistTests2
            , tensorIntermediateMnistTests2
            , tensorADOnceMnistTests2
            ]


-- * Running rev' on the gradient of afcnnMnistLoss2

-- | Second-order check: 'rev'' applied to a function that itself evaluates
-- a pre-built reverse-mode gradient artifact.
--
-- 'MnistFcnnRanked2.mnistTrainBench2VTOGradientX' yields initial parameters
-- @targetInit@ and an artifact @art@ whose domain is the nested pair
-- @(parameters, (glyph, label))@ and whose result is a 'TKScalar' 'Double'.
-- The local function @f@ maps a label to the glyph component of @art@'s
-- gradient, taken at @targetInit@ and a constant glyph of 7s, by
-- interpreting the artifact's derivative term in an arbitrary 'ADReady'
-- target. 'rev'' then differentiates @f@ at a constant label of 8s, and the
-- resulting gradient is compared with a pinned 10-element vector up to
-- @1e-10@.
mnistTestCase2VTOrev :: Assertion
mnistTestCase2VTOrev =
  let (!targetInit, !art) =
        MnistFcnnRanked2.mnistTrainBench2VTOGradientX
          @Double (Proxy @Float) IgnoreIncomingCotangent
          -- NOTE(review): the meaning of 1, 1500 and 500 is fixed by
          -- mnistTrainBench2VTOGradientX (presumably an init range and two
          -- hidden-layer widths) — confirm there.
          1 (mkStdGen 44) 1500 500
      -- Constant glyph input with every element equal to 7.
      blackGlyph = rreplicate sizeMnistGlyphInt $ rscalar 7
      -- Full shape of the initial parameters, needed by 'tconcrete' to
      -- embed them into whatever target @f@ is instantiated at.
      ftk = tftk @Concrete (knownSTK @(XParams2 Double Float))
                 targetInit
      -- Glyph component of the artifact's gradient, as a function of the
      -- label.
      f :: forall target r. (ADReady target, r ~ Double)
        => target (TKR 1 r) -> target (TKR 1 r)
      f label =
        let val = tpair (tconcrete ftk targetInit)
                        (tpair (rconcrete $ unConcrete blackGlyph) label)
            -- Bind the artifact's domain variable to
            -- (parameters, (glyph, label)).
            env = extendEnv (artVarDomainRev art) val emptyEnv
            -- tproject2 selects the (glyph, label) part of the gradient,
            -- tproject1 then the glyph part.
        in tproject1 $ tproject2
           $ interpretAst @target env (artDerivativeRev art)
  in assertEqualUpToEpsilon' 1e-10
       -- The shape [10] must match the label input of 'rev'' below
       -- (presumably sizeMnistLabelInt == 10).
       (ringestData [10] [ 6.922657834114052e-2, -3.2210167235305924e-5
                         , 0.12334696753032606, -4.892729845753193e-3
                         , 3.010762414514606e-2, 2.0344986964700877e-2
                         , -3.78339785604896e-2, 5.77360835535866e-2
                         , 0.10761507003315526, -7.909016076299641e-2 ])
       (rev' f (rreplicate sizeMnistLabelInt $ rscalar 8))


-- * Using lists of vectors, which is rank 1

-- | Tensor kind ('X') of the parameter bundle of the rank-1
-- ("lists of vectors") network from "MnistFcnnRanked1", taken at the
-- 'Concrete' target. The arguments are, in order, the widths of the first
-- and second hidden layers and the scalar type of the weights.
type XParams widthHidden widthHidden2 r
  = X (MnistFcnnRanked1.ADFcnnMnist1Parameters
         Concrete widthHidden widthHidden2 r)

-- POPL differentiation, straight via the ADVal instance of RankedTensor,
-- which side-steps vectorization.
mnistTestCase1VTA
  :: forall r.
     ( Differentiable r, GoodScalar r
     , PrintfArg r, AssertEqualUpToEpsilon r )
  => String
  -> Int -> Int -> Int -> Int -> Double -> Int -> r
  -> TestTree
mnistTestCase1VTA :: forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r) =>
TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> r -> TestTree
mnistTestCase1VTA TestName
prefix Int
epochs Int
maxBatches Int
widthHiddenInt Int
widthHidden2Int
                  Double
gamma Int
batchSize r
expected =
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHiddenInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
widthHiddenSNat :: SNat widthHidden) ->
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHidden2Int ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
widthHidden2SNat :: SNat widthHidden2) ->
  SingletonTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
-> (KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)) =>
    TestTree)
-> TestTree
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK
    (SingletonTK (TKS ((':) @Nat 784 ('[] @Nat)) r)
-> SNat n
-> SingletonTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
forall (t :: TK) (n :: Nat).
SingletonTK t -> SNat n -> SingletonTK (Tups n t)
stkOfListR (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(TKS '[SizeMnistGlyph] r)) (forall (n :: Nat). KnownNat n => SNat n
SNat @widthHidden)) ((KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)) => TestTree)
 -> TestTree)
-> (KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)) =>
    TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$
  SingletonTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
-> (KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
    TestTree)
-> TestTree
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK
    (SingletonTK (TKS ((':) @Nat n ('[] @Nat)) Float)
-> SNat n
-> SingletonTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
forall (t :: TK) (n :: Nat).
SingletonTK t -> SNat n -> SingletonTK (Tups n t)
stkOfListR (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(TKS '[widthHidden] Float)) (forall (n :: Nat). KnownNat n => SNat n
SNat @widthHidden2)) ((KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
  TestTree)
 -> TestTree)
-> (KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
    TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$
  let valsInit :: MnistFcnnRanked1.ADFcnnMnist1Parameters
                    Concrete widthHidden widthHidden2 r
      valsInit :: ADFcnnMnist1Parameters Concrete n n r
valsInit = (ADFcnnMnist1Parameters Concrete n n r, StdGen)
-> ADFcnnMnist1Parameters Concrete n n r
forall a b. (a, b) -> a
fst ((ADFcnnMnist1Parameters Concrete n n r, StdGen)
 -> ADFcnnMnist1Parameters Concrete n n r)
-> (ADFcnnMnist1Parameters Concrete n n r, StdGen)
-> ADFcnnMnist1Parameters Concrete n n r
forall a b. (a -> b) -> a -> b
$ Double
-> StdGen
-> (((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
      Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
     (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
      Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
     (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))),
    StdGen)
forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue Double
1 (Int -> StdGen
mkStdGen Int
44)
      targetInit :: Concrete (XParams widthHidden widthHidden2 r)
      targetInit :: Concrete (XParams n n r)
targetInit = forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
  Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
 (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
  Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
 (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
  Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
ADFcnnMnist1Parameters Concrete n n r
valsInit
      name :: TestName
name = TestName
prefix TestName -> TestName -> TestName
forall a. [a] -> [a] -> [a]
++ TestName
": "
             TestName -> TestName -> TestName
forall a. [a] -> [a] -> [a]
++ [TestName] -> TestName
unwords [ Int -> TestName
forall a. Show a => a -> TestName
show Int
epochs, Int -> TestName
forall a. Show a => a -> TestName
show Int
maxBatches
                        , Int -> TestName
forall a. Show a => a -> TestName
show Int
widthHiddenInt, Int -> TestName
forall a. Show a => a -> TestName
show Int
widthHidden2Int
                        , Int -> TestName
forall a. Show a => a -> TestName
show (Int -> TestName) -> Int -> TestName
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams n n r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
                          (SingletonTK (XParams n n r) -> Int)
-> SingletonTK (XParams n n r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams widthHidden widthHidden2 r)
                        , Int -> TestName
forall a. Show a => a -> TestName
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete (XParams n n r)
targetInit)
                        , Double -> TestName
forall a. Show a => a -> TestName
show Double
gamma ]
      ftest :: [MnistDataLinearR r]
            -> MnistFcnnRanked1.ADFcnnMnist1Parameters
                 Concrete widthHidden widthHidden2 r
            -> r
      ftest :: [MnistDataLinearR r] -> ADFcnnMnist1Parameters Concrete n n r -> r
ftest = SNat n
-> SNat n
-> [MnistDataLinearR r]
-> ADFcnnMnist1Parameters Concrete n n r
-> r
forall (target :: Target) (widthHidden :: Nat)
       (widthHidden2 :: Nat) r.
((target :: Target) ~ (Concrete :: Target), GoodScalar r,
 Differentiable r) =>
SNat widthHidden
-> SNat widthHidden2
-> [MnistDataLinearR r]
-> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
-> r
MnistFcnnRanked1.afcnnMnistTest1 SNat n
widthHiddenSNat SNat n
widthHidden2SNat
  in TestName -> Assertion -> TestTree
testCase TestName
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
    Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
      TestName -> TestName -> Int -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: Epochs to run/max batches per epoch: %d/%d"
             TestName
prefix Int
epochs Int
maxBatches
    trainData <- TestName -> TestName -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
TestName -> TestName -> IO [MnistData r]
loadMnistData TestName
trainGlyphsPath TestName
trainLabelsPath
    testData <- map mkMnistDataLinearR . take (batchSize * maxBatches)
                <$> loadMnistData testGlyphsPath testLabelsPath
    let f :: MnistDataLinearR r
          -> ADVal Concrete (XParams widthHidden widthHidden2 r)
          -> ADVal Concrete (TKScalar r)
        f (Ranked 1 r
glyph, Ranked 1 r
label) ADVal Concrete (XParams n n r)
adinputs =
          SNat n
-> SNat n
-> (ADVal Concrete (TKR 1 r), ADVal Concrete (TKR 1 r))
-> ADFcnnMnist1Parameters (ADVal Concrete) n n r
-> ADVal Concrete (TKScalar r)
forall (target :: Target) r (widthHidden :: Nat)
       (widthHidden2 :: Nat).
(ADReady target, GoodScalar r, Differentiable r) =>
SNat widthHidden
-> SNat widthHidden2
-> (target (TKR 1 r), target (TKR 1 r))
-> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
-> target (TKScalar r)
MnistFcnnRanked1.afcnnMnistLoss1
            SNat n
widthHiddenSNat SNat n
widthHidden2SNat
            (Ranked 1 r -> ADVal Concrete (TKR 1 r)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
glyph, Ranked 1 r -> ADVal Concrete (TKR 1 r)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
label) (ADVal
  Concrete
  (X ((ListR n (ADVal Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR
         SizeMnistLabel (ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       ADVal Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
-> ((ListR n (ADVal Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
     ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR n (ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
     ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR
       SizeMnistLabel (ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
     ADVal Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget ADVal
  Concrete
  (X ((ListR n (ADVal Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR
         SizeMnistLabel (ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       ADVal Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
ADVal Concrete (XParams n n r)
adinputs)
    -- Mimic how backprop tests and display it, even though tests
    -- should not print, in principle.
    let runBatch :: Concrete (XParams widthHidden widthHidden2 r)
                 -> (Int, [MnistDataLinearR r])
                 -> IO (Concrete (XParams widthHidden widthHidden2 r))
        runBatch !Concrete (XParams n n r)
params (Int
k, [MnistDataLinearR r]
chunk) = do
          let res :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
res = (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
            (TKS ((':) @Nat n ('[] @Nat)) r))
         (TKProduct
            (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
            (TKS ((':) @Nat n ('[] @Nat)) r)))
      (TKProduct
         (TKProduct
            (TKS ((':) @Nat n ('[] @Nat)) r)
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) r)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
         (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))),
 Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall a b. (a, b) -> a
fst ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
             (TKS ((':) @Nat n ('[] @Nat)) r))
          (TKProduct
             (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
             (TKS ((':) @Nat n ('[] @Nat)) r)))
       (TKProduct
          (TKProduct
             (TKS ((':) @Nat n ('[] @Nat)) r)
             (TKProduct
                (TKS ((':) @Nat n ('[] @Nat)) r)
                (TKProduct
                   (TKS ((':) @Nat n ('[] @Nat)) r)
                   (TKProduct
                      (TKS ((':) @Nat n ('[] @Nat)) r)
                      (TKProduct
                         (TKS ((':) @Nat n ('[] @Nat)) r)
                         (TKProduct
                            (TKS ((':) @Nat n ('[] @Nat)) r)
                            (TKProduct
                               (TKS ((':) @Nat n ('[] @Nat)) r)
                               (TKProduct
                                  (TKS ((':) @Nat n ('[] @Nat)) r)
                                  (TKProduct
                                     (TKS ((':) @Nat n ('[] @Nat)) r)
                                     (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
          (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))),
  Concrete (TKScalar r))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
               (TKS ((':) @Nat n ('[] @Nat)) r))
            (TKProduct
               (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
               (TKS ((':) @Nat n ('[] @Nat)) r)))
         (TKProduct
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) r)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) r)
                                       (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
            (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
               (TKS ((':) @Nat n ('[] @Nat)) r))
            (TKProduct
               (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
               (TKS ((':) @Nat n ('[] @Nat)) r)))
         (TKProduct
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) r)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) r)
                                       (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
            (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))),
    Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall a b. (a -> b) -> a -> b
$ Double
-> (MnistDataLinearR r
    -> ADVal
         Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                  (TKS ((':) @Nat n ('[] @Nat)) r))
               (TKProduct
                  (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                  (TKS ((':) @Nat n ('[] @Nat)) r)))
            (TKProduct
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) r)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) r)
                                          (TKProduct
                                             (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
               (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
    -> ADVal Concrete (TKScalar r))
-> [MnistDataLinearR r]
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
               (TKS ((':) @Nat n ('[] @Nat)) r))
            (TKProduct
               (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
               (TKS ((':) @Nat n ('[] @Nat)) r)))
         (TKProduct
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) r)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) r)
                                       (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
            (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))),
    Concrete (TKScalar r))
forall a (x :: TK) (z :: TK).
KnownSTK x =>
Double
-> (a -> ADVal Concrete x -> ADVal Concrete z)
-> [a]
-> Concrete x
-> (Concrete x, Concrete z)
sgd Double
gamma MnistDataLinearR r
-> ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> ADVal Concrete (TKScalar r)
MnistDataLinearR r
-> ADVal Concrete (XParams n n r) -> ADVal Concrete (TKScalar r)
f [MnistDataLinearR r]
chunk Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete (XParams n n r)
params
              trainScore :: r
trainScore = [MnistDataLinearR r] -> ADFcnnMnist1Parameters Concrete n n r -> r
ftest [MnistDataLinearR r]
chunk (Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
-> ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
res)
              testScore :: r
testScore = [MnistDataLinearR r] -> ADFcnnMnist1Parameters Concrete n n r -> r
ftest [MnistDataLinearR r]
testData (Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
-> ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
res)
              lenChunk :: Int
lenChunk = [MnistDataLinearR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataLinearR r]
chunk
          Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
widthHiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> Int -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: (Batch %d with %d points)"
                     TestName
prefix Int
k Int
lenChunk
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> r -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"%s: Training error:   %.2f%%"
                     TestName
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> r -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"%s: Validation error: %.2f%%"
                     TestName
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
          Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
res
    let runEpoch :: Int
                 -> Concrete (XParams widthHidden widthHidden2 r)
                 -> IO (Concrete (XParams widthHidden widthHidden2 r))
        runEpoch Int
n Concrete (XParams n n r)
params | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete (XParams n n r)
params
        runEpoch Int
n !Concrete (XParams n n r)
params = do
          Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
widthHiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$ TestName -> TestName -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: [Epoch %d]" TestName
prefix Int
n
          let trainDataShuffled :: [MnistData r]
trainDataShuffled = StdGen -> [MnistData r] -> [MnistData r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistData r]
trainData
              chunks :: [(Int, [MnistDataLinearR r])]
chunks = Int
-> [(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                       ([(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])])
-> [(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
Item [Int]
1 ..] ([[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])])
-> [[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataLinearR r] -> [[MnistDataLinearR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
batchSize
                       ([MnistDataLinearR r] -> [[MnistDataLinearR r]])
-> [MnistDataLinearR r] -> [[MnistDataLinearR r]]
forall a b. (a -> b) -> a -> b
$ (MnistData r -> MnistDataLinearR r)
-> [MnistData r] -> [MnistDataLinearR r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataLinearR r
forall r. PrimElt r => MnistData r -> MnistDataLinearR r
mkMnistDataLinearR [MnistData r]
trainDataShuffled
          res <- (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
            (TKS ((':) @Nat n ('[] @Nat)) r))
         (TKProduct
            (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
            (TKS ((':) @Nat n ('[] @Nat)) r)))
      (TKProduct
         (TKProduct
            (TKS ((':) @Nat n ('[] @Nat)) r)
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) r)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
         (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
 -> (Int, [MnistDataLinearR r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                  (TKS ((':) @Nat n ('[] @Nat)) r))
               (TKProduct
                  (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                  (TKS ((':) @Nat n ('[] @Nat)) r)))
            (TKProduct
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) r)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) r)
                                          (TKProduct
                                             (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
               (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> [(Int, [MnistDataLinearR r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> (Int, [MnistDataLinearR r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
Concrete (XParams n n r)
-> (Int, [MnistDataLinearR r]) -> IO (Concrete (XParams n n r))
runBatch Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete (XParams n n r)
params [(Int, [MnistDataLinearR r])]
chunks
          runEpoch (succ n) res
    res <- runEpoch 1 targetInit
    let testErrorFinal = r
1 r -> r -> r
forall a. Num a => a -> a -> a
- [MnistDataLinearR r] -> ADFcnnMnist1Parameters Concrete n n r -> r
ftest [MnistDataLinearR r]
testData (Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
-> ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
res)
    testErrorFinal @?~ expected

{-# SPECIALIZE mnistTestCase1VTA
  :: String
  -> Int -> Int -> Int -> Int -> Double -> Int -> Double
  -> TestTree #-}

tensorADValMnistTests :: TestTree
tensorADValMnistTests :: TestTree
tensorADValMnistTests = TestName -> [TestTree] -> TestTree
testGroup TestName
"Ranked ADVal MNIST tests"
  [ TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> Double -> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r) =>
TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> r -> TestTree
mnistTestCase1VTA TestName
"VTA1 1 epoch, 1 batch" Int
1 Int
1 Int
300 Int
100 Double
0.02 Int
5000
                      (Double
0.2146 :: Double)
  , TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> Float -> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r) =>
TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> r -> TestTree
mnistTestCase1VTA TestName
"VTA1 artificial 1 2 3 4 5" Int
1 Int
2 Int
3 Int
4 Double
5 Int
5000
                      (Float
0.8972 :: Float)
  , TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> Float -> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r) =>
TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> r -> TestTree
mnistTestCase1VTA TestName
"VTA1 1 epoch, 0 batch" Int
1 Int
0 Int
300 Int
100 Double
0.02 Int
5000
                      (Float
1 :: Float)
  ]

-- POPL differentiation, with Ast term defined and vectorized only once,
-- but differentiated anew in each gradient descent iteration.
mnistTestCase1VTI
  :: forall r.
     ( Differentiable r, GoodScalar r
     , PrintfArg r, AssertEqualUpToEpsilon r )
  => String
  -> Int -> Int -> Int -> Int -> Double -> Int -> r
  -> TestTree
mnistTestCase1VTI :: forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r) =>
TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> r -> TestTree
mnistTestCase1VTI TestName
prefix Int
epochs Int
maxBatches Int
widthHiddenInt Int
widthHidden2Int
                  Double
gamma Int
batchSize r
expected =
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHiddenInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
widthHiddenSNat :: SNat widthHidden) ->
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHidden2Int ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
widthHidden2SNat :: SNat widthHidden2) ->
  SingletonTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
-> (KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)) =>
    TestTree)
-> TestTree
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK
    (SingletonTK (TKS ((':) @Nat 784 ('[] @Nat)) r)
-> SNat n
-> SingletonTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
forall (t :: TK) (n :: Nat).
SingletonTK t -> SNat n -> SingletonTK (Tups n t)
stkOfListR (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(TKS '[SizeMnistGlyph] r)) (forall (n :: Nat). KnownNat n => SNat n
SNat @widthHidden)) ((KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)) => TestTree)
 -> TestTree)
-> (KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)) =>
    TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$
  SingletonTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
-> (KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
    TestTree)
-> TestTree
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK
    (SingletonTK (TKS ((':) @Nat n ('[] @Nat)) Float)
-> SNat n
-> SingletonTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
forall (t :: TK) (n :: Nat).
SingletonTK t -> SNat n -> SingletonTK (Tups n t)
stkOfListR (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(TKS '[widthHidden] Float)) (forall (n :: Nat). KnownNat n => SNat n
SNat @widthHidden2)) ((KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
  TestTree)
 -> TestTree)
-> (KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
    TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$
  let valsInit :: MnistFcnnRanked1.ADFcnnMnist1Parameters
                    Concrete widthHidden widthHidden2 r
      valsInit :: ADFcnnMnist1Parameters Concrete n n r
valsInit = (ADFcnnMnist1Parameters Concrete n n r, StdGen)
-> ADFcnnMnist1Parameters Concrete n n r
forall a b. (a, b) -> a
fst ((ADFcnnMnist1Parameters Concrete n n r, StdGen)
 -> ADFcnnMnist1Parameters Concrete n n r)
-> (ADFcnnMnist1Parameters Concrete n n r, StdGen)
-> ADFcnnMnist1Parameters Concrete n n r
forall a b. (a -> b) -> a -> b
$ Double
-> StdGen
-> (((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
      Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
     (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
      Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
     (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))),
    StdGen)
forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue Double
1 (Int -> StdGen
mkStdGen Int
44)
      targetInit :: Concrete (XParams widthHidden widthHidden2 r)
      targetInit :: Concrete (XParams n n r)
targetInit = forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
  Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
 (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
  Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
 (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
  Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
ADFcnnMnist1Parameters Concrete n n r
valsInit
      ftk :: FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
ftk = forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams widthHidden widthHidden2 r))
                 Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete (XParams n n r)
targetInit
      name :: TestName
name = TestName
prefix TestName -> TestName -> TestName
forall a. [a] -> [a] -> [a]
++ TestName
": "
             TestName -> TestName -> TestName
forall a. [a] -> [a] -> [a]
++ [TestName] -> TestName
unwords [ Int -> TestName
forall a. Show a => a -> TestName
show Int
epochs, Int -> TestName
forall a. Show a => a -> TestName
show Int
maxBatches
                        , Int -> TestName
forall a. Show a => a -> TestName
show Int
widthHiddenInt, Int -> TestName
forall a. Show a => a -> TestName
show Int
widthHidden2Int
                        , Int -> TestName
forall a. Show a => a -> TestName
show (Int -> TestName) -> Int -> TestName
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams n n r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
                          (SingletonTK (XParams n n r) -> Int)
-> SingletonTK (XParams n n r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams widthHidden widthHidden2 r)
                        , Int -> TestName
forall a. Show a => a -> TestName
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete (XParams n n r)
targetInit)
                        , Double -> TestName
forall a. Show a => a -> TestName
show Double
gamma ]
      ftest :: [MnistDataLinearR r]
            -> MnistFcnnRanked1.ADFcnnMnist1Parameters
                 Concrete widthHidden widthHidden2 r
            -> r
      ftest :: [MnistDataLinearR r] -> ADFcnnMnist1Parameters Concrete n n r -> r
ftest = SNat n
-> SNat n
-> [MnistDataLinearR r]
-> ADFcnnMnist1Parameters Concrete n n r
-> r
forall (target :: Target) (widthHidden :: Nat)
       (widthHidden2 :: Nat) r.
((target :: Target) ~ (Concrete :: Target), GoodScalar r,
 Differentiable r) =>
SNat widthHidden
-> SNat widthHidden2
-> [MnistDataLinearR r]
-> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
-> r
MnistFcnnRanked1.afcnnMnistTest1 SNat n
widthHiddenSNat SNat n
widthHidden2SNat
  in TestName -> Assertion -> TestTree
testCase TestName
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
    Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
      TestName -> TestName -> Int -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: Epochs to run/max batches per epoch: %d/%d"
             TestName
prefix Int
epochs Int
maxBatches
    trainData <- TestName -> TestName -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
TestName -> TestName -> IO [MnistData r]
loadMnistData TestName
trainGlyphsPath TestName
trainLabelsPath
    testData <- map mkMnistDataLinearR . take (batchSize * maxBatches)
                <$> loadMnistData testGlyphsPath testLabelsPath
    (_, _, var, varAst) <- funToAstRevIO ftk
    (varGlyph, astGlyph) <-
      funToAstIO (FTKR (sizeMnistGlyphInt :$: ZSR) FTKScalar) id
    (varLabel, astLabel) <-
      funToAstIO (FTKR (sizeMnistLabelInt :$: ZSR) FTKScalar) id
    let ast :: AstTensor AstMethodLet FullSpan (TKScalar r)
        ast = AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (z :: TK) (s :: AstSpanType).
AstSpan s =>
AstTensor AstMethodLet s z -> AstTensor AstMethodLet s z
simplifyInline
              (AstTensor AstMethodLet FullSpan (TKScalar r)
 -> AstTensor AstMethodLet FullSpan (TKScalar r))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall a b. (a -> b) -> a -> b
$ SNat n
-> SNat n
-> (AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)),
    AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)))
-> ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) n n r
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (target :: Target) r (widthHidden :: Nat)
       (widthHidden2 :: Nat).
(ADReady target, GoodScalar r, Differentiable r) =>
SNat widthHidden
-> SNat widthHidden2
-> (target (TKR 1 r), target (TKR 1 r))
-> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
-> target (TKScalar r)
MnistFcnnRanked1.afcnnMnistLoss1
                  SNat n
widthHiddenSNat SNat n
widthHidden2SNat (AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
astGlyph, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
astLabel)
                  (AstTensor
  AstMethodLet
  FullSpan
  (X ((ListR
         n
         (AstTensor
            AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR
         n
         (AstTensor
            AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
       AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR
         SizeMnistLabel
         (AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
       AstTensor
         AstMethodLet
         FullSpan
         (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
-> ((ListR
       n
       (AstTensor
          AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) r)),
     AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR
       n
       (AstTensor
          AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
     AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR
       SizeMnistLabel
       (AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
     AstTensor
       AstMethodLet
       FullSpan
       (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget AstTensor
  AstMethodLet
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
AstTensor
  AstMethodLet
  FullSpan
  (X ((ListR
         n
         (AstTensor
            AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR
         n
         (AstTensor
            AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
       AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR
         SizeMnistLabel
         (AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
       AstTensor
         AstMethodLet
         FullSpan
         (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
varAst)
        f :: MnistDataLinearR r
          -> ADVal Concrete (XParams widthHidden widthHidden2 r)
          -> ADVal Concrete (TKScalar r)
        f (Ranked 1 r
glyph, Ranked 1 r
label) ADVal Concrete (XParams n n r)
varInputs =
          let env :: AstEnv (ADVal Concrete)
env = AstVarName
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
var ADVal
  Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
ADVal Concrete (XParams n n r)
varInputs AstEnv (ADVal Concrete)
forall (target :: Target). AstEnv target
emptyEnv
              envMnist :: AstEnv (ADVal Concrete)
envMnist = AstVarName FullSpan (TKR2 1 (TKScalar r))
-> ADVal Concrete (TKR2 1 (TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName FullSpan (TKR2 1 (TKScalar r))
varGlyph (Ranked 1 r -> ADVal Concrete (TKR2 1 (TKScalar r))
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
glyph)
                         (AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete))
-> AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete)
forall a b. (a -> b) -> a -> b
$ AstVarName FullSpan (TKR2 1 (TKScalar r))
-> ADVal Concrete (TKR2 1 (TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName FullSpan (TKR2 1 (TKScalar r))
varLabel (Ranked 1 r -> ADVal Concrete (TKR2 1 (TKScalar r))
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
label) AstEnv (ADVal Concrete)
env
          in AstEnv (ADVal Concrete)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> ADVal Concrete (TKScalar r)
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull AstEnv (ADVal Concrete)
envMnist AstTensor AstMethodLet FullSpan (TKScalar r)
ast
    let runBatch :: Concrete (XParams widthHidden widthHidden2 r)
                 -> (Int, [MnistDataLinearR r])
                 -> IO (Concrete (XParams widthHidden widthHidden2 r))
        runBatch !Concrete (XParams n n r)
params (Int
k, [MnistDataLinearR r]
chunk) = do
          let res :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
res = (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
            (TKS ((':) @Nat n ('[] @Nat)) r))
         (TKProduct
            (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
            (TKS ((':) @Nat n ('[] @Nat)) r)))
      (TKProduct
         (TKProduct
            (TKS ((':) @Nat n ('[] @Nat)) r)
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) r)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
         (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))),
 Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall a b. (a, b) -> a
fst ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
             (TKS ((':) @Nat n ('[] @Nat)) r))
          (TKProduct
             (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
             (TKS ((':) @Nat n ('[] @Nat)) r)))
       (TKProduct
          (TKProduct
             (TKS ((':) @Nat n ('[] @Nat)) r)
             (TKProduct
                (TKS ((':) @Nat n ('[] @Nat)) r)
                (TKProduct
                   (TKS ((':) @Nat n ('[] @Nat)) r)
                   (TKProduct
                      (TKS ((':) @Nat n ('[] @Nat)) r)
                      (TKProduct
                         (TKS ((':) @Nat n ('[] @Nat)) r)
                         (TKProduct
                            (TKS ((':) @Nat n ('[] @Nat)) r)
                            (TKProduct
                               (TKS ((':) @Nat n ('[] @Nat)) r)
                               (TKProduct
                                  (TKS ((':) @Nat n ('[] @Nat)) r)
                                  (TKProduct
                                     (TKS ((':) @Nat n ('[] @Nat)) r)
                                     (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
          (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))),
  Concrete (TKScalar r))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
               (TKS ((':) @Nat n ('[] @Nat)) r))
            (TKProduct
               (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
               (TKS ((':) @Nat n ('[] @Nat)) r)))
         (TKProduct
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) r)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) r)
                                       (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
            (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
               (TKS ((':) @Nat n ('[] @Nat)) r))
            (TKProduct
               (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
               (TKS ((':) @Nat n ('[] @Nat)) r)))
         (TKProduct
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) r)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) r)
                                       (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
            (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))),
    Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall a b. (a -> b) -> a -> b
$ Double
-> (MnistDataLinearR r
    -> ADVal
         Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                  (TKS ((':) @Nat n ('[] @Nat)) r))
               (TKProduct
                  (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                  (TKS ((':) @Nat n ('[] @Nat)) r)))
            (TKProduct
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) r)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) r)
                                          (TKProduct
                                             (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
               (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
    -> ADVal Concrete (TKScalar r))
-> [MnistDataLinearR r]
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
               (TKS ((':) @Nat n ('[] @Nat)) r))
            (TKProduct
               (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
               (TKS ((':) @Nat n ('[] @Nat)) r)))
         (TKProduct
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) r)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) r)
                                       (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
            (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))),
    Concrete (TKScalar r))
forall a (x :: TK) (z :: TK).
KnownSTK x =>
Double
-> (a -> ADVal Concrete x -> ADVal Concrete z)
-> [a]
-> Concrete x
-> (Concrete x, Concrete z)
sgd Double
gamma MnistDataLinearR r
-> ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> ADVal Concrete (TKScalar r)
MnistDataLinearR r
-> ADVal Concrete (XParams n n r) -> ADVal Concrete (TKScalar r)
f [MnistDataLinearR r]
chunk Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete (XParams n n r)
params
              trainScore :: r
trainScore = [MnistDataLinearR r] -> ADFcnnMnist1Parameters Concrete n n r -> r
ftest [MnistDataLinearR r]
chunk (Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
-> ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
res)
              testScore :: r
testScore = [MnistDataLinearR r] -> ADFcnnMnist1Parameters Concrete n n r -> r
ftest [MnistDataLinearR r]
testData (Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
-> ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
res)
              lenChunk :: Int
lenChunk = [MnistDataLinearR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataLinearR r]
chunk
          Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
widthHiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> Int -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: (Batch %d with %d points)"
                     TestName
prefix Int
k Int
lenChunk
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> r -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"%s: Training error:   %.2f%%"
                     TestName
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> r -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"%s: Validation error: %.2f%%"
                     TestName
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
          Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
res
    let runEpoch :: Int
                 -> Concrete (XParams widthHidden widthHidden2 r)
                 -> IO (Concrete (XParams widthHidden widthHidden2 r))
        runEpoch Int
n Concrete (XParams n n r)
params | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete (XParams n n r)
params
        runEpoch Int
n !Concrete (XParams n n r)
params = do
          Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
widthHiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$ TestName -> TestName -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: [Epoch %d]" TestName
prefix Int
n
          let trainDataShuffled :: [MnistData r]
trainDataShuffled = StdGen -> [MnistData r] -> [MnistData r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) [MnistData r]
trainData
              chunks :: [(Int, [MnistDataLinearR r])]
chunks = Int
-> [(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                       ([(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])])
-> [(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
Item [Int]
1 ..] ([[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])])
-> [[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataLinearR r] -> [[MnistDataLinearR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
batchSize
                       ([MnistDataLinearR r] -> [[MnistDataLinearR r]])
-> [MnistDataLinearR r] -> [[MnistDataLinearR r]]
forall a b. (a -> b) -> a -> b
$ (MnistData r -> MnistDataLinearR r)
-> [MnistData r] -> [MnistDataLinearR r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataLinearR r
forall r. PrimElt r => MnistData r -> MnistDataLinearR r
mkMnistDataLinearR [MnistData r]
trainDataShuffled
          res <- (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
            (TKS ((':) @Nat n ('[] @Nat)) r))
         (TKProduct
            (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
            (TKS ((':) @Nat n ('[] @Nat)) r)))
      (TKProduct
         (TKProduct
            (TKS ((':) @Nat n ('[] @Nat)) r)
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) r)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
         (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
 -> (Int, [MnistDataLinearR r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                  (TKS ((':) @Nat n ('[] @Nat)) r))
               (TKProduct
                  (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                  (TKS ((':) @Nat n ('[] @Nat)) r)))
            (TKProduct
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) r)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) r)
                                          (TKProduct
                                             (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
               (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> [(Int, [MnistDataLinearR r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> (Int, [MnistDataLinearR r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
Concrete (XParams n n r)
-> (Int, [MnistDataLinearR r]) -> IO (Concrete (XParams n n r))
runBatch Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete (XParams n n r)
params [(Int, [MnistDataLinearR r])]
chunks
          runEpoch (succ n) res
    res <- runEpoch 1 targetInit
    let testErrorFinal = r
1 r -> r -> r
forall a. Num a => a -> a -> a
- [MnistDataLinearR r] -> ADFcnnMnist1Parameters Concrete n n r -> r
ftest [MnistDataLinearR r]
testData (Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
-> ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
res)
    testErrorFinal @?~ expected

-- Pre-compile 'mnistTestCase1VTI' at the scalar type r ~ Double: the last
-- argument (the expected test error) is 'Double' in this instance.
-- Only this one instance is specialised by the pragma.
{-# SPECIALIZE mnistTestCase1VTI ::
      String -> Int -> Int -> Int -> Int -> Double -> Int -> Double
      -> TestTree #-}

-- | Test group running 'mnistTestCase1VTI' on the "MnistFcnnRanked1"
-- network. Each case fixes its own scalar type ('Double' or 'Float')
-- through the annotation on its final argument, the expected test error.
--
-- NOTE(review): the numeric arguments presumably follow the same order as
-- 'mnistTestCase1VTO' (epochs, max batches, widthHidden, widthHidden2,
-- gamma, batch size, expected error). Confirm against the definition
-- of 'mnistTestCase1VTI'.
tensorIntermediateMnistTests :: TestTree
tensorIntermediateMnistTests =
  testGroup "Ranked Intermediate MNIST tests" cases
 where
  cases :: [TestTree]
  cases =
    [ mnistTestCase1VTI "VTI1 1 epoch, 1 batch"
        1 1 300 100 0.02 5000 (0.2116 :: Double)
    , mnistTestCase1VTI "VTI1 artificial 1 2 3 4 5"
        1 2 3 4 5 5000 (0.9108 :: Float)
    , mnistTestCase1VTI "VTI1 1 epoch, 0 batch"
        1 0 300 100 0.02 5000 (1 :: Float)
    ]

-- JAX differentiation, Ast term built and differentiated only once
-- and the result interpreted with different inputs in each gradient
-- descent iteration.
mnistTestCase1VTO
  :: forall r.
     ( Differentiable r, GoodScalar r, ADTensorScalar r ~ r
     , PrintfArg r, AssertEqualUpToEpsilon r)
  => String
  -> Int -> Int -> Int -> Int -> Double -> Int -> r
  -> TestTree
mnistTestCase1VTO :: forall r.
(Differentiable r, GoodScalar r,
 (ADTensorScalar r :: Type) ~ (r :: Type), PrintfArg r,
 AssertEqualUpToEpsilon r) =>
TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> r -> TestTree
mnistTestCase1VTO TestName
prefix Int
epochs Int
maxBatches Int
widthHiddenInt Int
widthHidden2Int
                  Double
gamma Int
batchSize r
expected =
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHiddenInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
widthHiddenSNat :: SNat widthHidden) ->
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHidden2Int ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
widthHidden2SNat :: SNat widthHidden2) ->
  SingletonTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
-> (KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)) =>
    TestTree)
-> TestTree
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK
    (SingletonTK (TKS ((':) @Nat 784 ('[] @Nat)) r)
-> SNat n
-> SingletonTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
forall (t :: TK) (n :: Nat).
SingletonTK t -> SNat n -> SingletonTK (Tups n t)
stkOfListR (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(TKS '[SizeMnistGlyph] r)) (forall (n :: Nat). KnownNat n => SNat n
SNat @widthHidden)) ((KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)) => TestTree)
 -> TestTree)
-> (KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)) =>
    TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$
  SingletonTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
-> (KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
    TestTree)
-> TestTree
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK
    (SingletonTK (TKS ((':) @Nat n ('[] @Nat)) Float)
-> SNat n
-> SingletonTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
forall (t :: TK) (n :: Nat).
SingletonTK t -> SNat n -> SingletonTK (Tups n t)
stkOfListR (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(TKS '[widthHidden] Float)) (forall (n :: Nat). KnownNat n => SNat n
SNat @widthHidden2)) ((KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
  TestTree)
 -> TestTree)
-> (KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
    TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$
  let valsInit :: MnistFcnnRanked1.ADFcnnMnist1Parameters
                    Concrete widthHidden widthHidden2 r
      valsInit :: ADFcnnMnist1Parameters Concrete n n r
valsInit = (ADFcnnMnist1Parameters Concrete n n r, StdGen)
-> ADFcnnMnist1Parameters Concrete n n r
forall a b. (a, b) -> a
fst ((ADFcnnMnist1Parameters Concrete n n r, StdGen)
 -> ADFcnnMnist1Parameters Concrete n n r)
-> (ADFcnnMnist1Parameters Concrete n n r, StdGen)
-> ADFcnnMnist1Parameters Concrete n n r
forall a b. (a -> b) -> a -> b
$ Double
-> StdGen
-> (((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
      Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
     (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
      Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
     (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))),
    StdGen)
forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue Double
1 (Int -> StdGen
mkStdGen Int
44)
      targetInit :: Concrete (XParams widthHidden widthHidden2 r)
      targetInit :: Concrete (XParams n n r)
targetInit = forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
  Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
 (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
  Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
 (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
  Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
ADFcnnMnist1Parameters Concrete n n r
valsInit
      name :: TestName
name = TestName
prefix TestName -> TestName -> TestName
forall a. [a] -> [a] -> [a]
++ TestName
": "
             TestName -> TestName -> TestName
forall a. [a] -> [a] -> [a]
++ [TestName] -> TestName
unwords [ Int -> TestName
forall a. Show a => a -> TestName
show Int
epochs, Int -> TestName
forall a. Show a => a -> TestName
show Int
maxBatches
                        , Int -> TestName
forall a. Show a => a -> TestName
show Int
widthHiddenInt, Int -> TestName
forall a. Show a => a -> TestName
show Int
widthHidden2Int
                        , Int -> TestName
forall a. Show a => a -> TestName
show (Int -> TestName) -> Int -> TestName
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams n n r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
                          (SingletonTK (XParams n n r) -> Int)
-> SingletonTK (XParams n n r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams widthHidden widthHidden2 r)
                        , Int -> TestName
forall a. Show a => a -> TestName
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete (XParams n n r)
targetInit)
                        , Double -> TestName
forall a. Show a => a -> TestName
show Double
gamma ]
      ftest :: [MnistDataLinearR r]
            -> MnistFcnnRanked1.ADFcnnMnist1Parameters
                 Concrete widthHidden widthHidden2 r
            -> r
      ftest :: [MnistDataLinearR r] -> ADFcnnMnist1Parameters Concrete n n r -> r
ftest = SNat n
-> SNat n
-> [MnistDataLinearR r]
-> ADFcnnMnist1Parameters Concrete n n r
-> r
forall (target :: Target) (widthHidden :: Nat)
       (widthHidden2 :: Nat) r.
((target :: Target) ~ (Concrete :: Target), GoodScalar r,
 Differentiable r) =>
SNat widthHidden
-> SNat widthHidden2
-> [MnistDataLinearR r]
-> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
-> r
MnistFcnnRanked1.afcnnMnistTest1 SNat n
widthHiddenSNat SNat n
widthHidden2SNat
  in TestName -> Assertion -> TestTree
testCase TestName
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
    Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
      TestName -> TestName -> Int -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: Epochs to run/max batches per epoch: %d/%d"
             TestName
prefix Int
epochs Int
maxBatches
    trainData <- TestName -> TestName -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
TestName -> TestName -> IO [MnistData r]
loadMnistData TestName
trainGlyphsPath TestName
trainLabelsPath
    testData <- map mkMnistDataLinearR . take (batchSize * maxBatches)
                <$> loadMnistData testGlyphsPath testLabelsPath
    let dataInit = case [MnistDataLinearR r]
testData of
          MnistDataLinearR r
d : [MnistDataLinearR r]
_ -> (Ranked 1 r -> Concrete (TKR 1 r)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete (Ranked 1 r -> Concrete (TKR 1 r))
-> (Ranked 1 r -> Concrete (TKR 1 r))
-> MnistDataLinearR r
-> (Concrete (TKR 1 r), Concrete (TKR 1 r))
forall b c b' c'. (b -> c) -> (b' -> c') -> (b, b') -> (c, c')
forall (a :: Type -> Type -> Type) b c b' c'.
Arrow a =>
a b c -> a b' c' -> a (b, b') (c, c')
*** Ranked 1 r -> Concrete (TKR 1 r)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete) MnistDataLinearR r
d
          [] -> TestName -> (Concrete (TKR 1 r), Concrete (TKR 1 r))
forall a. HasCallStack => TestName -> a
error TestName
"empty test data"
        f :: ( MnistFcnnRanked1.ADFcnnMnist1Parameters
                 (AstTensor AstMethodLet FullSpan)
                 widthHidden widthHidden2 r
             , ( AstTensor AstMethodLet FullSpan (TKR 1 r)
               , AstTensor AstMethodLet FullSpan (TKR 1 r) ) )
          -> AstTensor AstMethodLet FullSpan (TKScalar r)
        f = \ (ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) n n r
pars, (AstTensor AstMethodLet FullSpan (TKR 1 r)
glyphR, AstTensor AstMethodLet FullSpan (TKR 1 r)
labelR)) ->
          SNat n
-> SNat n
-> (AstTensor AstMethodLet FullSpan (TKR 1 r),
    AstTensor AstMethodLet FullSpan (TKR 1 r))
-> ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) n n r
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (target :: Target) r (widthHidden :: Nat)
       (widthHidden2 :: Nat).
(ADReady target, GoodScalar r, Differentiable r) =>
SNat widthHidden
-> SNat widthHidden2
-> (target (TKR 1 r), target (TKR 1 r))
-> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
-> target (TKScalar r)
MnistFcnnRanked1.afcnnMnistLoss1
            SNat n
widthHiddenSNat SNat n
widthHidden2SNat
            (AstTensor AstMethodLet FullSpan (TKR 1 r)
glyphR, AstTensor AstMethodLet FullSpan (TKR 1 r)
labelR) ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) n n r
pars
        artRaw = ((((ListR
      n
      (AstTensor
         AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) r)),
    AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
   (ListR
      n
      (AstTensor
         AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
    AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
   (ListR
      SizeMnistLabel
      (AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
    AstTensor
      AstMethodLet
      FullSpan
      (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))),
  (AstTensor AstMethodLet FullSpan (TKR 1 r),
   AstTensor AstMethodLet FullSpan (TKR 1 r)))
 -> AstTensor AstMethodLet FullSpan (TKScalar r))
-> Value
     (((ListR
          n
          (AstTensor
             AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) r)),
        AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
       (ListR
          n
          (AstTensor
             AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
        AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
       (ListR
          SizeMnistLabel
          (AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
        AstTensor
          AstMethodLet
          FullSpan
          (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))),
      (AstTensor AstMethodLet FullSpan (TKR 1 r),
       AstTensor AstMethodLet FullSpan (TKR 1 r)))
-> AstArtifactRev
     (X (((ListR
             n
             (AstTensor
                AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) r)),
           AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
          (ListR
             n
             (AstTensor
                AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
           AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
          (ListR
             SizeMnistLabel
             (AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
           AstTensor
             AstMethodLet
             FullSpan
             (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))),
         (AstTensor AstMethodLet FullSpan (TKR 1 r),
          AstTensor AstMethodLet FullSpan (TKR 1 r))))
     (TKScalar r)
forall src r tgt.
((X src :: TK) ~ (X (Value src) :: TK), KnownSTK (X src),
 AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
 AdaptableTarget Concrete (Value src),
 (tgt :: Type)
 ~ (AstTensor AstMethodLet FullSpan (TKScalar r) :: Type)) =>
(src -> tgt) -> Value src -> AstArtifactRev (X src) (TKScalar r)
gradArtifact (((ListR
     n
     (AstTensor
        AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) r)),
   AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
  (ListR
     n
     (AstTensor
        AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
   AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
  (ListR
     SizeMnistLabel
     (AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
   AstTensor
     AstMethodLet
     FullSpan
     (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))),
 (AstTensor AstMethodLet FullSpan (TKR 1 r),
  AstTensor AstMethodLet FullSpan (TKR 1 r)))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
(ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) n n r,
 (AstTensor AstMethodLet FullSpan (TKR 1 r),
  AstTensor AstMethodLet FullSpan (TKR 1 r)))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f (((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
  Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
 (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
  Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
 (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
  Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
ADFcnnMnist1Parameters Concrete n n r
valsInit, (Concrete (TKR 1 r), Concrete (TKR 1 r))
dataInit)
        art = AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
  (TKScalar r)
-> AstArtifactRev
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
        (TKProduct (TKR 1 r) (TKR 1 r)))
     (TKScalar r)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifactGradient AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
  (TKScalar r)
AstArtifactRev
  (X (((ListR
          n
          (AstTensor
             AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) r)),
        AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
       (ListR
          n
          (AstTensor
             AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
        AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
       (ListR
          SizeMnistLabel
          (AstTensor AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) r)),
        AstTensor
          AstMethodLet
          FullSpan
          (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))),
      (AstTensor AstMethodLet FullSpan (TKR 1 r),
       AstTensor AstMethodLet FullSpan (TKR 1 r))))
  (TKScalar r)
artRaw
        go :: [MnistDataLinearR r]
           -> Concrete (XParams widthHidden widthHidden2 r)
           -> Concrete (XParams widthHidden widthHidden2 r)
        go [] Concrete (XParams n n r)
parameters = Concrete (XParams n n r)
parameters
        go ((Ranked 1 r
glyph, Ranked 1 r
label) : [MnistDataLinearR r]
rest) !Concrete (XParams n n r)
parameters =
          let parametersAndInput :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
parametersAndInput =
                Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> Concrete (TKProduct (TKR 1 r) (TKR 1 r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
        (TKProduct (TKR 1 r) (TKR 1 r)))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete (XParams n n r)
parameters (Concrete (TKR 1 r)
-> Concrete (TKR 1 r) -> Concrete (TKProduct (TKR 1 r) (TKR 1 r))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (Ranked 1 r -> Concrete (TKR 1 r)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
glyph) (Ranked 1 r -> Concrete (TKR 1 r)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
label))
              gradient :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
gradient = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall (x :: TK) (z :: TK). Concrete (TKProduct x z) -> Concrete x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)))
               (TKS ((':) @Nat n ('[] @Nat)) r))
            (TKProduct
               (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
               (TKS ((':) @Nat n ('[] @Nat)) r)))
         (TKProduct
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) r)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) r)
                                       (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
            (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
      (TKProduct (TKR 1 r) (TKR 1 r)))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)))
               (TKS ((':) @Nat n ('[] @Nat)) r))
            (TKProduct
               (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
               (TKS ((':) @Nat n ('[] @Nat)) r)))
         (TKProduct
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) r)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) r)
                                       (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
            (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
        (TKProduct (TKR 1 r) (TKR 1 r)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall a b. (a -> b) -> a -> b
$ (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)))
               (TKS ((':) @Nat n ('[] @Nat)) r))
            (TKProduct
               (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
               (TKS ((':) @Nat n ('[] @Nat)) r)))
         (TKProduct
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) r)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) r)
                                       (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
            (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
      (TKProduct (TKR 1 r) (TKR 1 r))),
 Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
        (TKProduct (TKR 1 r) (TKR 1 r)))
forall a b. (a, b) -> a
fst
                         ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct
                (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)))
                (TKS ((':) @Nat n ('[] @Nat)) r))
             (TKProduct
                (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
                (TKS ((':) @Nat n ('[] @Nat)) r)))
          (TKProduct
             (TKProduct
                (TKS ((':) @Nat n ('[] @Nat)) r)
                (TKProduct
                   (TKS ((':) @Nat n ('[] @Nat)) r)
                   (TKProduct
                      (TKS ((':) @Nat n ('[] @Nat)) r)
                      (TKProduct
                         (TKS ((':) @Nat n ('[] @Nat)) r)
                         (TKProduct
                            (TKS ((':) @Nat n ('[] @Nat)) r)
                            (TKProduct
                               (TKS ((':) @Nat n ('[] @Nat)) r)
                               (TKProduct
                                  (TKS ((':) @Nat n ('[] @Nat)) r)
                                  (TKProduct
                                     (TKS ((':) @Nat n ('[] @Nat)) r)
                                     (TKProduct
                                        (TKS ((':) @Nat n ('[] @Nat)) r)
                                        (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
             (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
       (TKProduct (TKR 1 r) (TKR 1 r))),
  Concrete (TKScalar r))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)))
                  (TKS ((':) @Nat n ('[] @Nat)) r))
               (TKProduct
                  (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
                  (TKS ((':) @Nat n ('[] @Nat)) r)))
            (TKProduct
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) r)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) r)
                                          (TKProduct
                                             (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
               (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
         (TKProduct (TKR 1 r) (TKR 1 r))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)))
                  (TKS ((':) @Nat n ('[] @Nat)) r))
               (TKProduct
                  (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
                  (TKS ((':) @Nat n ('[] @Nat)) r)))
            (TKProduct
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) r)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) r)
                                          (TKProduct
                                             (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
               (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
         (TKProduct (TKR 1 r) (TKR 1 r))),
    Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
        (TKProduct (TKR 1 r) (TKR 1 r)))
forall a b. (a -> b) -> a -> b
$ AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
  (TKScalar r)
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
        (TKProduct (TKR 1 r) (TKR 1 r)))
-> Maybe (Concrete (ADTensorKind (TKScalar r)))
-> (Concrete
      (ADTensorKind
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct
                     (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                     (TKS ((':) @Nat n ('[] @Nat)) r))
                  (TKProduct
                     (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                     (TKS ((':) @Nat n ('[] @Nat)) r)))
               (TKProduct
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) r)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) r)
                                          (TKProduct
                                             (TKS ((':) @Nat n ('[] @Nat)) r)
                                             (TKProduct
                                                (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
                  (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
            (TKProduct (TKR 1 r) (TKR 1 r)))),
    Concrete (TKScalar r))
forall (x :: TK) (z :: TK).
AstArtifactRev x z
-> Concrete x
-> Maybe (Concrete (ADTensorKind z))
-> (Concrete (ADTensorKind x), Concrete z)
revInterpretArtifact AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
  (TKScalar r)
art Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
parametersAndInput Maybe (Concrete (ADTensorKind (TKScalar r)))
Maybe (Concrete (TKScalar r))
forall a. Maybe a
Nothing
          in [MnistDataLinearR r]
-> Concrete (XParams n n r) -> Concrete (XParams n n r)
go [MnistDataLinearR r]
rest (Double
-> SingletonTK
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> Concrete
     (ADTensorKind
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall (y :: TK).
Double
-> SingletonTK y
-> Concrete y
-> Concrete (ADTensorKind y)
-> Concrete y
updateWithGradient Double
gamma SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete (XParams n n r)
parameters Concrete
  (ADTensorKind
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r)))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
gradient)
    let runBatch :: Concrete (XParams widthHidden widthHidden2 r)
                 -> (Int, [MnistDataLinearR r])
                 -> IO (Concrete (XParams widthHidden widthHidden2 r))
        runBatch !Concrete (XParams n n r)
params (Int
k, [MnistDataLinearR r]
chunk) = do
          let res :: Concrete (XParams n n r)
res = [MnistDataLinearR r]
-> Concrete (XParams n n r) -> Concrete (XParams n n r)
go [MnistDataLinearR r]
chunk Concrete (XParams n n r)
params
              trainScore :: r
trainScore = [MnistDataLinearR r] -> ADFcnnMnist1Parameters Concrete n n r -> r
ftest [MnistDataLinearR r]
chunk (Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
-> ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
Concrete (XParams n n r)
res)
              testScore :: r
testScore = [MnistDataLinearR r] -> ADFcnnMnist1Parameters Concrete n n r -> r
ftest [MnistDataLinearR r]
testData (Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
-> ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
Concrete (XParams n n r)
res)
              lenChunk :: Int
lenChunk = [MnistDataLinearR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataLinearR r]
chunk
          Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
widthHiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> Int -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: (Batch %d with %d points)"
                     TestName
prefix Int
k Int
lenChunk
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> r -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"%s: Training error:   %.2f%%"
                     TestName
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> r -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"%s: Validation error: %.2f%%"
                     TestName
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
          Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete (XParams n n r)
res
    let runEpoch :: Int
                 -> Concrete (XParams widthHidden widthHidden2 r)
                 -> IO (Concrete (XParams widthHidden widthHidden2 r))
        runEpoch Int
n Concrete (XParams n n r)
params | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete (XParams n n r)
params
        runEpoch Int
n !Concrete (XParams n n r)
params = do
          Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
widthHiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$ TestName -> TestName -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: [Epoch %d]" TestName
prefix Int
n
          let trainDataShuffled :: [MnistData r]
trainDataShuffled = StdGen -> [MnistData r] -> [MnistData r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) [MnistData r]
trainData
              chunks :: [(Int, [MnistDataLinearR r])]
chunks = Int
-> [(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                       ([(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])])
-> [(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
Item [Int]
1 ..] ([[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])])
-> [[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataLinearR r] -> [[MnistDataLinearR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
batchSize
                       ([MnistDataLinearR r] -> [[MnistDataLinearR r]])
-> [MnistDataLinearR r] -> [[MnistDataLinearR r]]
forall a b. (a -> b) -> a -> b
$ (MnistData r -> MnistDataLinearR r)
-> [MnistData r] -> [MnistDataLinearR r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataLinearR r
forall r. PrimElt r => MnistData r -> MnistDataLinearR r
mkMnistDataLinearR [MnistData r]
trainDataShuffled
          res <- (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
            (TKS ((':) @Nat n ('[] @Nat)) r))
         (TKProduct
            (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
            (TKS ((':) @Nat n ('[] @Nat)) r)))
      (TKProduct
         (TKProduct
            (TKS ((':) @Nat n ('[] @Nat)) r)
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) r)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
         (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
 -> (Int, [MnistDataLinearR r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                  (TKS ((':) @Nat n ('[] @Nat)) r))
               (TKProduct
                  (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                  (TKS ((':) @Nat n ('[] @Nat)) r)))
            (TKProduct
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) r)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) r)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) r)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) r)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) r)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) r)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) r)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) r)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) r)
                                          (TKProduct
                                             (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
               (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
              (TKS ((':) @Nat n ('[] @Nat)) r))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) r)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> [(Int, [MnistDataLinearR r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
-> (Int, [MnistDataLinearR r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
                 (TKS ((':) @Nat n ('[] @Nat)) r))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) r)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) r)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) r)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
Concrete (XParams n n r)
-> (Int, [MnistDataLinearR r]) -> IO (Concrete (XParams n n r))
runBatch Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete (XParams n n r)
params [(Int, [MnistDataLinearR r])]
chunks
          runEpoch (succ n) res
    res <- runEpoch 1 targetInit
    let testErrorFinal = r
1 r -> r -> r
forall a. Num a => a -> a -> a
- [MnistDataLinearR r] -> ADFcnnMnist1Parameters Concrete n n r -> r
ftest [MnistDataLinearR r]
testData (Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
-> ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
     Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
    (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
     Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) r))
           (TKS ((':) @Nat n ('[] @Nat)) r))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) r)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) r)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) r)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) r)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) r)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) r)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) r)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) r)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) r)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) r)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) r) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)))
Concrete
  (X ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
      (ListR SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) r)),
       Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r))))
res)
    testErrorFinal @?~ expected

-- Ask GHC for a monomorphic copy of 'mnistTestCase1VTO' with the scalar
-- type fixed to 'Double'. This pragma covers only 'Double'; the 'Float'
-- cases in 'tensorADOnceMnistTests' are not covered by it.
{-# SPECIALIZE mnistTestCase1VTO ::
      String -> Int -> Int -> Int -> Int -> Double -> Int -> Double
      -> TestTree #-}

-- | MNIST test group for the rank-1 network, run through
-- 'mnistTestCase1VTO'.
--
-- Each case supplies a test-name prefix, then six 'Int'/'Double' knobs,
-- then the expected final test error. The knobs are assumed to be, in
-- order: epochs, max batches per epoch, the two hidden-layer widths,
-- @gamma@ and batch size, matching the parameter order of
-- 'mnistTestCase2VTA' (TODO confirm against 'mnistTestCase1VTO').
-- The type annotation on the expected error selects the scalar type @r@
-- ('Double' or 'Float') that the network is trained at.
tensorADOnceMnistTests :: TestTree
tensorADOnceMnistTests =
  testGroup "Ranked Once MNIST tests" cases
  where
    cases :: [TestTree]
    cases =
      [ mnistTestCase1VTO "VTO1 1 epoch, 1 batch"
          1 1 300 100 0.02 5000 (0.2116 :: Double)
      , mnistTestCase1VTO "VTO1 artificial 1 2 3 4 5"
          1 2 3 4 5 5000 (0.9108 :: Float)
        -- With zero batches nothing is trained, and the expected error is 1.
      , mnistTestCase1VTO "VTO1 1 epoch, 0 batch"
          1 0 300 100 0.02 5000 (1 :: Float)
      ]


-- * Using matrices, i.e., rank-2 tensors

-- POPL differentiation, straight via the ADVal instance of RankedTensor,
-- which side-steps vectorization.
mnistTestCase2VTA
  :: forall r.
     ( Differentiable r, GoodScalar r
     , PrintfArg r, AssertEqualUpToEpsilon r )
  => String
  -> Int -> Int -> Int -> Int -> Double -> Int -> r
  -> TestTree
mnistTestCase2VTA :: forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r) =>
TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> r -> TestTree
mnistTestCase2VTA TestName
prefix Int
epochs Int
maxBatches Int
widthHidden Int
widthHidden2
                  Double
gamma Int
batchSize r
expected =
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHidden ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat @widthHidden) ->
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHidden2 ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat @widthHidden2) ->
  let targetInit :: NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit =
        Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
 -> NoShape
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKS2
                  ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a b. (a -> b) -> a -> b
$ (Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r Float)),
 StdGen)
-> Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r Float))
forall a b. (a, b) -> a
fst
        ((Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r Float)),
  StdGen)
 -> Concrete
      (X (ADFcnnMnist2ParametersShaped Concrete n n r Float)))
-> (Concrete
      (X (ADFcnnMnist2ParametersShaped Concrete n n r Float)),
    StdGen)
-> Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r Float))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue
            @(Concrete (X (MnistFcnnRanked2.ADFcnnMnist2ParametersShaped
                             Concrete widthHidden widthHidden2 r Float)))
            Double
1 (Int -> StdGen
mkStdGen Int
44)
      name :: TestName
name = TestName
prefix TestName -> TestName -> TestName
forall a. [a] -> [a] -> [a]
++ TestName
": "
             TestName -> TestName -> TestName
forall a. [a] -> [a] -> [a]
++ [TestName] -> TestName
unwords [ Int -> TestName
forall a. Show a => a -> TestName
show Int
epochs, Int -> TestName
forall a. Show a => a -> TestName
show Int
maxBatches
                        , Int -> TestName
forall a. Show a => a -> TestName
show Int
widthHidden, Int -> TestName
forall a. Show a => a -> TestName
show Int
widthHidden2
                        , Int -> TestName
forall a. Show a => a -> TestName
show (Int -> TestName) -> Int -> TestName
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams2 r Float) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK (SingletonTK (XParams2 r Float) -> Int)
-> SingletonTK (XParams2 r Float) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams2 r Float)
                        , Int -> TestName
forall a. Show a => a -> TestName
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit)
                        , Double -> TestName
forall a. Show a => a -> TestName
show Double
gamma ]
  in TestName -> Assertion -> TestTree
testCase TestName
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
    Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
      TestName -> TestName -> Int -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: Epochs to run/max batches per epoch: %d/%d"
             TestName
prefix Int
epochs Int
maxBatches
    trainData <- TestName -> TestName -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
TestName -> TestName -> IO [MnistData r]
loadMnistData TestName
trainGlyphsPath TestName
trainLabelsPath
    testData <- map mkMnistDataLinearR . take (batchSize * maxBatches)
                <$> loadMnistData testGlyphsPath testLabelsPath
    let f :: MnistDataLinearR r -> ADVal Concrete (XParams2 r Float)
          -> ADVal Concrete (TKScalar r)
        f (Ranked 1 r
glyph, Ranked 1 r
label) ADVal Concrete (XParams2 r Float)
adinputs =
          (ADVal Concrete (TKR2 1 (TKScalar r)),
 ADVal Concrete (TKR2 1 (TKScalar r)))
-> ADFcnnMnist2Parameters (ADVal Concrete) r Float
-> ADVal Concrete (TKScalar r)
forall (target :: Target) r q.
(ADReady target, GoodScalar r, Differentiable r, GoodScalar q,
 Differentiable q) =>
(target (TKR 1 r), target (TKR 1 r))
-> ADFcnnMnist2Parameters target r q -> target (TKScalar r)
MnistFcnnRanked2.afcnnMnistLoss2
            (Ranked 1 r -> ADVal Concrete (TKR2 1 (TKScalar r))
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
glyph, Ranked 1 r -> ADVal Concrete (TKR2 1 (TKScalar r))
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
label) (ADVal
  Concrete (X (ADFcnnMnist2Parameters (ADVal Concrete) r Float))
-> ADFcnnMnist2Parameters (ADVal Concrete) r Float
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget ADVal
  Concrete (X (ADFcnnMnist2Parameters (ADVal Concrete) r Float))
ADVal Concrete (XParams2 r Float)
adinputs)
    let runBatch :: Concrete (XParams2 r Float) -> (Int, [MnistDataLinearR r])
                 -> IO (Concrete (XParams2 r Float))
        runBatch !Concrete (XParams2 r Float)
params (Int
k, [MnistDataLinearR r]
chunk) = do
          let res :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
res = (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
 Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall a b. (a, b) -> a
fst ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
          (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
       (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
  Concrete (TKScalar r))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
    Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall a b. (a -> b) -> a -> b
$ Double
-> (MnistDataLinearR r
    -> ADVal
         Concrete
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
    -> ADVal Concrete (TKScalar r))
-> [MnistDataLinearR r]
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
    Concrete (TKScalar r))
forall a (x :: TK) (z :: TK).
KnownSTK x =>
Double
-> (a -> ADVal Concrete x -> ADVal Concrete z)
-> [a]
-> Concrete x
-> (Concrete x, Concrete z)
sgd Double
gamma MnistDataLinearR r
-> ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> ADVal Concrete (TKScalar r)
MnistDataLinearR r
-> ADVal Concrete (XParams2 r Float) -> ADVal Concrete (TKScalar r)
f [MnistDataLinearR r]
chunk Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
params
              trainScore :: r
trainScore =
                [MnistDataLinearR r]
-> ADFcnnMnist2Parameters Concrete r Float -> r
forall (target :: Target) r q.
((target :: Target) ~ (Concrete :: Target), GoodScalar r,
 Differentiable r, GoodScalar q, Differentiable q) =>
[MnistDataLinearR r] -> ADFcnnMnist2Parameters target r q -> r
MnistFcnnRanked2.afcnnMnistTest2 [MnistDataLinearR r]
chunk (Concrete (XParams2 r Float)
-> ADFcnnMnist2Parameters Concrete r Float
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
res)
              testScore :: r
testScore =
                [MnistDataLinearR r]
-> ADFcnnMnist2Parameters Concrete r Float -> r
forall (target :: Target) r q.
((target :: Target) ~ (Concrete :: Target), GoodScalar r,
 Differentiable r, GoodScalar q, Differentiable q) =>
[MnistDataLinearR r] -> ADFcnnMnist2Parameters target r q -> r
MnistFcnnRanked2.afcnnMnistTest2 [MnistDataLinearR r]
testData (Concrete (XParams2 r Float)
-> ADFcnnMnist2Parameters Concrete r Float
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
res)
              lenChunk :: Int
lenChunk = [MnistDataLinearR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataLinearR r]
chunk
          Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
widthHidden Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> Int -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: (Batch %d with %d points)"
                     TestName
prefix Int
k Int
lenChunk
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> r -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"%s: Training error:   %.2f%%"
                     TestName
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> r -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"%s: Validation error: %.2f%%"
                     TestName
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
          Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
res
    let runEpoch :: Int -> Concrete (XParams2 r Float)
                 -> IO (Concrete (XParams2 r Float))
        runEpoch Int
n Concrete (XParams2 r Float)
params | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
params
        runEpoch Int
n !Concrete (XParams2 r Float)
params = do
          Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
widthHidden Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$ TestName -> TestName -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: [Epoch %d]" TestName
prefix Int
n
          let trainDataShuffled :: [MnistData r]
trainDataShuffled = StdGen -> [MnistData r] -> [MnistData r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistData r]
trainData
              chunks :: [(Int, [MnistDataLinearR r])]
chunks = Int
-> [(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                       ([(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])])
-> [(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
Item [Int]
1 ..] ([[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])])
-> [[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataLinearR r] -> [[MnistDataLinearR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
batchSize
                       ([MnistDataLinearR r] -> [[MnistDataLinearR r]])
-> [MnistDataLinearR r] -> [[MnistDataLinearR r]]
forall a b. (a -> b) -> a -> b
$ (MnistData r -> MnistDataLinearR r)
-> [MnistData r] -> [MnistDataLinearR r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataLinearR r
forall r. PrimElt r => MnistData r -> MnistDataLinearR r
mkMnistDataLinearR [MnistData r]
trainDataShuffled
          res <- (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
 -> (Int, [MnistDataLinearR r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> [(Int, [MnistDataLinearR r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> (Int, [MnistDataLinearR r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams2 r Float)
-> (Int, [MnistDataLinearR r]) -> IO (Concrete (XParams2 r Float))
runBatch Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
params [(Int, [MnistDataLinearR r])]
chunks
          runEpoch (succ n) res
    res <- runEpoch 1 targetInit
    let testErrorFinal =
          r
1 r -> r -> r
forall a. Num a => a -> a -> a
- [MnistDataLinearR r]
-> ADFcnnMnist2Parameters Concrete r Float -> r
forall (target :: Target) r q.
((target :: Target) ~ (Concrete :: Target), GoodScalar r,
 Differentiable r, GoodScalar q, Differentiable q) =>
[MnistDataLinearR r] -> ADFcnnMnist2Parameters target r q -> r
MnistFcnnRanked2.afcnnMnistTest2 [MnistDataLinearR r]
testData (Concrete (XParams2 r Float)
-> ADFcnnMnist2Parameters Concrete r Float
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
res)
    testErrorFinal @?~ expected

-- Specialise the polymorphic test driver 'mnistTestCase2VTA' at @r ~ Double@
-- (the last argument, the expected result, has type 'Double' here).
-- NOTE(review): only the 'Double' instantiation is listed; the calls with a
-- 'Float' expected value in 'tensorADValMnistTests2' are not covered by this
-- pragma — consider adding a 'Float' specialisation if that matters.
{-# SPECIALIZE mnistTestCase2VTA
  :: String
  -> Int -> Int -> Int -> Int -> Double -> Int -> Double
  -> TestTree #-}

-- | MNIST tests of the "MnistFcnnRanked2" network, each run through the
-- ADVal-based driver 'mnistTestCase2VTA'.
--
-- Every case passes, in order: a test-name prefix, four 'Int' settings,
-- a 'Double' (presumably the SGD step size @gamma@, as in
-- 'mnistTestCase2VTI'), an 'Int' batch size, and the expected final
-- validation error. The type annotation on that expected value ('Double'
-- or 'Float') selects the scalar type @r@ used for the whole run.
tensorADValMnistTests2 :: TestTree
tensorADValMnistTests2 = testGroup "Ranked2 ADVal MNIST tests" cases
  where
    cases :: [TestTree]
    cases =
      [ mnistTestCase2VTA "VTA2 1 epoch, 1 batch"
          1 1 300 100 0.02 5000 (0.21299999999999997 :: Double)
        -- Tiny, artificial layer sizes; exercises the pipeline at 'Float'.
      , mnistTestCase2VTA "VTA2 artificial 1 2 3 4 5"
          1 2 3 4 5 5000 (0.8972 :: Float)
      , mnistTestCase2VTA "VTA2 artificial 5 4 3 2 1"
          5 4 3 2 1 5000 (0.6805 :: Double)
        -- Zero batches per epoch: no training step is taken.
      , mnistTestCase2VTA "VTA2 1 epoch, 0 batch"
          1 0 300 100 0.02 5000 (1 :: Float)
      ]

-- POPL differentiation, with Ast term defined and vectorized only once,
-- but differentiated anew in each gradient descent iteration.
mnistTestCase2VTI
  :: forall r.
     ( Differentiable r, GoodScalar r
     , PrintfArg r, AssertEqualUpToEpsilon r )
  => String
  -> Int -> Int -> Int -> Int -> Double -> Int -> r
  -> TestTree
mnistTestCase2VTI :: forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r) =>
TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> r -> TestTree
mnistTestCase2VTI TestName
prefix Int
epochs Int
maxBatches Int
widthHidden Int
widthHidden2
                  Double
gamma Int
batchSize r
expected =
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHidden ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat @widthHidden) ->
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHidden2 ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat @widthHidden2) ->
  let targetInit :: NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit =
        Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
 -> NoShape
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKS2
                  ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a b. (a -> b) -> a -> b
$ (Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r Float)),
 StdGen)
-> Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r Float))
forall a b. (a, b) -> a
fst
        ((Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r Float)),
  StdGen)
 -> Concrete
      (X (ADFcnnMnist2ParametersShaped Concrete n n r Float)))
-> (Concrete
      (X (ADFcnnMnist2ParametersShaped Concrete n n r Float)),
    StdGen)
-> Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r Float))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue
            @(Concrete (X (MnistFcnnRanked2.ADFcnnMnist2ParametersShaped
                             Concrete widthHidden widthHidden2 r Float)))
            Double
1 (Int -> StdGen
mkStdGen Int
44)
      name :: TestName
name = TestName
prefix TestName -> TestName -> TestName
forall a. [a] -> [a] -> [a]
++ TestName
": "
             TestName -> TestName -> TestName
forall a. [a] -> [a] -> [a]
++ [TestName] -> TestName
unwords [ Int -> TestName
forall a. Show a => a -> TestName
show Int
epochs, Int -> TestName
forall a. Show a => a -> TestName
show Int
maxBatches
                        , Int -> TestName
forall a. Show a => a -> TestName
show Int
widthHidden, Int -> TestName
forall a. Show a => a -> TestName
show Int
widthHidden2
                        , Int -> TestName
forall a. Show a => a -> TestName
show (Int -> TestName) -> Int -> TestName
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams2 r Float) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK (SingletonTK (XParams2 r Float) -> Int)
-> SingletonTK (XParams2 r Float) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams2 r Float)
                        , Int -> TestName
forall a. Show a => a -> TestName
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit)
                        , Double -> TestName
forall a. Show a => a -> TestName
show Double
gamma ]
  in TestName -> Assertion -> TestTree
testCase TestName
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
    Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
      TestName -> TestName -> Int -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: Epochs to run/max batches per epoch: %d/%d"
             TestName
prefix Int
epochs Int
maxBatches
    trainData <- TestName -> TestName -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
TestName -> TestName -> IO [MnistData r]
loadMnistData TestName
trainGlyphsPath TestName
trainLabelsPath
    testData <- map mkMnistDataLinearR . take (batchSize * maxBatches)
                <$> loadMnistData testGlyphsPath testLabelsPath
    let ftk = forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams2 r Float)) Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit
    (_, _, var, varAst) <- funToAstRevIO ftk
    (varGlyph, astGlyph) <-
      funToAstIO (FTKR (sizeMnistGlyphInt :$: ZSR) FTKScalar) id
    (varLabel, astLabel) <-
      funToAstIO (FTKR (sizeMnistLabelInt :$: ZSR) FTKScalar) id
    let ast :: AstTensor AstMethodLet FullSpan (TKScalar r)
        ast = AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (z :: TK) (s :: AstSpanType).
AstSpan s =>
AstTensor AstMethodLet s z -> AstTensor AstMethodLet s z
simplifyInline
              (AstTensor AstMethodLet FullSpan (TKScalar r)
 -> AstTensor AstMethodLet FullSpan (TKScalar r))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall a b. (a -> b) -> a -> b
$ (AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)),
 AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)))
-> ADFcnnMnist2Parameters (AstTensor AstMethodLet FullSpan) r Float
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (target :: Target) r q.
(ADReady target, GoodScalar r, Differentiable r, GoodScalar q,
 Differentiable q) =>
(target (TKR 1 r), target (TKR 1 r))
-> ADFcnnMnist2Parameters target r q -> target (TKScalar r)
MnistFcnnRanked2.afcnnMnistLoss2
                  (AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
astGlyph, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
astLabel)
                  (AstTensor
  AstMethodLet
  FullSpan
  (X (ADFcnnMnist2Parameters
        (AstTensor AstMethodLet FullSpan) r Float))
-> ADFcnnMnist2Parameters (AstTensor AstMethodLet FullSpan) r Float
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget AstTensor
  AstMethodLet
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
AstTensor
  AstMethodLet
  FullSpan
  (X (ADFcnnMnist2Parameters
        (AstTensor AstMethodLet FullSpan) r Float))
varAst)
        f :: MnistDataLinearR r -> ADVal Concrete (XParams2 r Float)
          -> ADVal Concrete (TKScalar r)
        f (Ranked 1 r
glyph, Ranked 1 r
label) ADVal Concrete (XParams2 r Float)
varInputs =
          let env :: AstEnv (ADVal Concrete)
env = AstVarName
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
  FullSpan
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
var ADVal
  Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
ADVal Concrete (XParams2 r Float)
varInputs AstEnv (ADVal Concrete)
forall (target :: Target). AstEnv target
emptyEnv
              envMnist :: AstEnv (ADVal Concrete)
envMnist = AstVarName FullSpan (TKR2 1 (TKScalar r))
-> ADVal Concrete (TKR2 1 (TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName FullSpan (TKR2 1 (TKScalar r))
varGlyph (Ranked 1 r -> ADVal Concrete (TKR2 1 (TKScalar r))
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
glyph)
                         (AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete))
-> AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete)
forall a b. (a -> b) -> a -> b
$ AstVarName FullSpan (TKR2 1 (TKScalar r))
-> ADVal Concrete (TKR2 1 (TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: Target) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName FullSpan (TKR2 1 (TKScalar r))
varLabel (Ranked 1 r -> ADVal Concrete (TKR2 1 (TKScalar r))
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
label) AstEnv (ADVal Concrete)
env
          in AstEnv (ADVal Concrete)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> ADVal Concrete (TKScalar r)
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull AstEnv (ADVal Concrete)
envMnist AstTensor AstMethodLet FullSpan (TKScalar r)
ast
    let runBatch :: Concrete (XParams2 r Float) -> (Int, [MnistDataLinearR r])
                 -> IO (Concrete (XParams2 r Float))
        runBatch !Concrete (XParams2 r Float)
params (Int
k, [MnistDataLinearR r]
chunk) = do
          let res :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
res = (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
 Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall a b. (a, b) -> a
fst ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
          (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
       (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
  Concrete (TKScalar r))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
    Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall a b. (a -> b) -> a -> b
$ Double
-> (MnistDataLinearR r
    -> ADVal
         Concrete
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
    -> ADVal Concrete (TKScalar r))
-> [MnistDataLinearR r]
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
    Concrete (TKScalar r))
forall a (x :: TK) (z :: TK).
KnownSTK x =>
Double
-> (a -> ADVal Concrete x -> ADVal Concrete z)
-> [a]
-> Concrete x
-> (Concrete x, Concrete z)
sgd Double
gamma MnistDataLinearR r
-> ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> ADVal Concrete (TKScalar r)
MnistDataLinearR r
-> ADVal Concrete (XParams2 r Float) -> ADVal Concrete (TKScalar r)
f [MnistDataLinearR r]
chunk Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
params
              trainScore :: r
trainScore =
                [MnistDataLinearR r]
-> ADFcnnMnist2Parameters Concrete r Float -> r
forall (target :: Target) r q.
((target :: Target) ~ (Concrete :: Target), GoodScalar r,
 Differentiable r, GoodScalar q, Differentiable q) =>
[MnistDataLinearR r] -> ADFcnnMnist2Parameters target r q -> r
MnistFcnnRanked2.afcnnMnistTest2 [MnistDataLinearR r]
chunk (Concrete (XParams2 r Float)
-> ADFcnnMnist2Parameters Concrete r Float
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
res)
              testScore :: r
testScore =
                [MnistDataLinearR r]
-> ADFcnnMnist2Parameters Concrete r Float -> r
forall (target :: Target) r q.
((target :: Target) ~ (Concrete :: Target), GoodScalar r,
 Differentiable r, GoodScalar q, Differentiable q) =>
[MnistDataLinearR r] -> ADFcnnMnist2Parameters target r q -> r
MnistFcnnRanked2.afcnnMnistTest2 [MnistDataLinearR r]
testData (Concrete (XParams2 r Float)
-> ADFcnnMnist2Parameters Concrete r Float
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
res)
              lenChunk :: Int
lenChunk = [MnistDataLinearR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataLinearR r]
chunk
          Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
widthHidden Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> Int -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: (Batch %d with %d points)"
                     TestName
prefix Int
k Int
lenChunk
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> r -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"%s: Training error:   %.2f%%"
                     TestName
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> r -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"%s: Validation error: %.2f%%"
                     TestName
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
          Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
res
    let runEpoch :: Int -> Concrete (XParams2 r Float)
                 -> IO (Concrete (XParams2 r Float))
        runEpoch Int
n Concrete (XParams2 r Float)
params | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
params
        runEpoch Int
n !Concrete (XParams2 r Float)
params = do
          Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
widthHidden Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$ TestName -> TestName -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: [Epoch %d]" TestName
prefix Int
n
          let trainDataShuffled :: [MnistData r]
trainDataShuffled = StdGen -> [MnistData r] -> [MnistData r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) [MnistData r]
trainData
              chunks :: [(Int, [MnistDataLinearR r])]
chunks = Int
-> [(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                       ([(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])])
-> [(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
Item [Int]
1 ..] ([[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])])
-> [[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataLinearR r] -> [[MnistDataLinearR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
batchSize
                       ([MnistDataLinearR r] -> [[MnistDataLinearR r]])
-> [MnistDataLinearR r] -> [[MnistDataLinearR r]]
forall a b. (a -> b) -> a -> b
$ (MnistData r -> MnistDataLinearR r)
-> [MnistData r] -> [MnistDataLinearR r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataLinearR r
forall r. PrimElt r => MnistData r -> MnistDataLinearR r
mkMnistDataLinearR [MnistData r]
trainDataShuffled
          res <- (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
 -> (Int, [MnistDataLinearR r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> [(Int, [MnistDataLinearR r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> (Int, [MnistDataLinearR r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams2 r Float)
-> (Int, [MnistDataLinearR r]) -> IO (Concrete (XParams2 r Float))
runBatch Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
params [(Int, [MnistDataLinearR r])]
chunks
          runEpoch (succ n) res
    res <- runEpoch 1 targetInit
    let testErrorFinal =
          r
1 r -> r -> r
forall a. Num a => a -> a -> a
- [MnistDataLinearR r]
-> ADFcnnMnist2Parameters Concrete r Float -> r
forall (target :: Target) r q.
((target :: Target) ~ (Concrete :: Target), GoodScalar r,
 Differentiable r, GoodScalar q, Differentiable q) =>
[MnistDataLinearR r] -> ADFcnnMnist2Parameters target r q -> r
MnistFcnnRanked2.afcnnMnistTest2 [MnistDataLinearR r]
testData (Concrete (XParams2 r Float)
-> ADFcnnMnist2Parameters Concrete r Float
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
res)
    testErrorFinal @?~ expected

-- Ask GHC for a copy of the polymorphic 'mnistTestCase2VTI' driver
-- monomorphised at @r ~ Double@ (the type of the expected-error argument).
{-# SPECIALIZE mnistTestCase2VTI
  :: String
  -> Int -> Int -> Int -> Int -> Double -> Int -> Double
  -> TestTree #-}

-- | Ranked2 MNIST tests run through the intermediate pipeline
-- ('mnistTestCase2VTI'), mixing 'Double' and 'Float' scalar types.
--
-- Each case supplies a name prefix, then (presumably in the same order as
-- 'mnistTestCase2VTO') epochs, max batches per epoch, the two hidden-layer
-- widths, the learning rate, the batch size, and finally the expected final
-- error, whose annotated type selects the scalar type of the run.
tensorIntermediateMnistTests2 :: TestTree
tensorIntermediateMnistTests2 =
  testGroup "Ranked2 Intermediate MNIST tests" intermediateCases
 where
  -- Explicit signature: the file enables OverloadedLists, so pin the
  -- list literal to a plain list of tests.
  intermediateCases :: [TestTree]
  intermediateCases =
    [ mnistTestCase2VTI "VTI2 1 epoch, 1 batch"
        1 1 300 100 0.02 5000 (0.20779999999999998 :: Double)
    , mnistTestCase2VTI "VTI2 artificial 1 2 3 4 5"
        1 2 3 4 5 5000 (0.9108 :: Float)
    , mnistTestCase2VTI "VTI2 artificial 5 4 3 2 1"
        5 4 3 2 1 5000 (0.8129 :: Double)
      -- Zero batches: no training and no test points are taken.
    , mnistTestCase2VTI "VTI2 1 epoch, 0 batch"
        1 0 300 100 0.02 5000 (1 :: Float)
    ]

-- JAX differentiation, Ast term built and differentiated only once
-- and the result interpreted with different inputs in each gradient
-- descent iteration.
mnistTestCase2VTO
  :: forall r.
     ( Differentiable r, GoodScalar r
     , PrintfArg r, AssertEqualUpToEpsilon r, ADTensorScalar r ~ r )
  => String
  -> Int -> Int -> Int -> Int -> Double -> Int -> r
  -> TestTree
mnistTestCase2VTO :: forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r,
 (ADTensorScalar r :: Type) ~ (r :: Type)) =>
TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> r -> TestTree
mnistTestCase2VTO TestName
prefix Int
epochs Int
maxBatches Int
widthHidden Int
widthHidden2
                  Double
gamma Int
batchSize r
expected =
  let (!Concrete (XParams2 r Float)
targetInit, !AstArtifactRev
  (TKProduct
     (XParams2 r Float)
     (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
  (TKScalar r)
artRaw) =
        forall r q.
(GoodScalar r, Differentiable r, GoodScalar q, Differentiable q) =>
Proxy @Type q
-> IncomingCotangentHandling
-> Double
-> StdGen
-> Int
-> Int
-> (Concrete (XParams2 r q),
    AstArtifactRev
      (TKProduct
         (XParams2 r q)
         (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKScalar r))
MnistFcnnRanked2.mnistTrainBench2VTOGradient
          @r (forall t. Proxy @Type t
forall {k} (t :: k). Proxy @k t
Proxy @Float) IncomingCotangentHandling
IgnoreIncomingCotangent
          Double
1 (Int -> StdGen
mkStdGen Int
44) Int
widthHidden Int
widthHidden2
      !art :: AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
  (TKScalar r)
art = AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
  (TKScalar r)
-> AstArtifactRev
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKScalar r)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifactGradient AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
  (TKScalar r)
AstArtifactRev
  (TKProduct
     (XParams2 r Float)
     (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
  (TKScalar r)
artRaw
      name :: TestName
name = TestName
prefix TestName -> TestName -> TestName
forall a. [a] -> [a] -> [a]
++ TestName
": "
             TestName -> TestName -> TestName
forall a. [a] -> [a] -> [a]
++ [TestName] -> TestName
unwords [ Int -> TestName
forall a. Show a => a -> TestName
show Int
epochs, Int -> TestName
forall a. Show a => a -> TestName
show Int
maxBatches
                        , Int -> TestName
forall a. Show a => a -> TestName
show Int
widthHidden, Int -> TestName
forall a. Show a => a -> TestName
show Int
widthHidden2
                        , Int -> TestName
forall a. Show a => a -> TestName
show (Int -> TestName) -> Int -> TestName
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams2 r Float) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK (SingletonTK (XParams2 r Float) -> Int)
-> SingletonTK (XParams2 r Float) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams2 r Float)
                        , Int -> TestName
forall a. Show a => a -> TestName
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
targetInit)
                        , Double -> TestName
forall a. Show a => a -> TestName
show Double
gamma ]
  in TestName -> Assertion -> TestTree
testCase TestName
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
    Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
      TestName -> TestName -> Int -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: Epochs to run/max batches per epoch: %d/%d"
             TestName
prefix Int
epochs Int
maxBatches
    trainData <- TestName -> TestName -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
TestName -> TestName -> IO [MnistData r]
loadMnistData TestName
trainGlyphsPath TestName
trainLabelsPath
    testData <- map mkMnistDataLinearR . take (batchSize * maxBatches)
                <$> loadMnistData testGlyphsPath testLabelsPath
    let go :: [MnistDataLinearR r] -> Concrete (XParams2 r Float)
           -> Concrete (XParams2 r Float)
        go [] Concrete (XParams2 r Float)
parameters = Concrete (XParams2 r Float)
parameters
        go ((Ranked 1 r
glyph, Ranked 1 r
label) : [MnistDataLinearR r]
rest) !Concrete (XParams2 r Float)
parameters =
          let parametersAndInput :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
parametersAndInput =
                Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Concrete (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
parameters (Concrete (TKR2 1 (TKScalar r))
-> Concrete (TKR2 1 (TKScalar r))
-> Concrete (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r)))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (Ranked 1 r -> Concrete (TKR2 1 (TKScalar r))
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
glyph) (Ranked 1 r -> Concrete (TKR2 1 (TKScalar r))
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
label))
              gradient :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
gradient = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall (x :: TK) (z :: TK). Concrete (TKProduct x z) -> Concrete x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall a b. (a -> b) -> a -> b
$ (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
            (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r)))),
 Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall a b. (a, b) -> a
fst
                         ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
             (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
          (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
       (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r)))),
  Concrete (TKScalar r))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
         (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r)))),
    Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall a b. (a -> b) -> a -> b
$ AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
  (TKScalar r)
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Maybe (Concrete (ADTensorKind (TKScalar r)))
-> (Concrete
      (ADTensorKind
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
                  (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
            (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))),
    Concrete (TKScalar r))
forall (x :: TK) (z :: TK).
AstArtifactRev x z
-> Concrete x
-> Maybe (Concrete (ADTensorKind z))
-> (Concrete (ADTensorKind x), Concrete z)
revInterpretArtifact AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
  (TKScalar r)
art Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
parametersAndInput Maybe (Concrete (ADTensorKind (TKScalar r)))
Maybe (Concrete (TKScalar r))
forall a. Maybe a
Nothing
          in [MnistDataLinearR r]
-> Concrete (XParams2 r Float) -> Concrete (XParams2 r Float)
go [MnistDataLinearR r]
rest (Double
-> SingletonTK
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Concrete
     (ADTensorKind
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall (y :: TK).
Double
-> SingletonTK y
-> Concrete y
-> Concrete (ADTensorKind y)
-> Concrete y
updateWithGradient Double
gamma SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
parameters Concrete
  (ADTensorKind
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
gradient)
    let runBatch :: Concrete (XParams2 r Float) -> (Int, [MnistDataLinearR r])
                 -> IO (Concrete (XParams2 r Float))
        runBatch !Concrete (XParams2 r Float)
params (Int
k, [MnistDataLinearR r]
chunk) = do
          let res :: Concrete (XParams2 r Float)
res = [MnistDataLinearR r]
-> Concrete (XParams2 r Float) -> Concrete (XParams2 r Float)
go [MnistDataLinearR r]
chunk Concrete (XParams2 r Float)
params
              trainScore :: r
trainScore =
                [MnistDataLinearR r]
-> ADFcnnMnist2Parameters Concrete r Float -> r
forall (target :: Target) r q.
((target :: Target) ~ (Concrete :: Target), GoodScalar r,
 Differentiable r, GoodScalar q, Differentiable q) =>
[MnistDataLinearR r] -> ADFcnnMnist2Parameters target r q -> r
MnistFcnnRanked2.afcnnMnistTest2 [MnistDataLinearR r]
chunk (Concrete (XParams2 r Float)
-> ADFcnnMnist2Parameters Concrete r Float
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete (XParams2 r Float)
res)
              testScore :: r
testScore =
                [MnistDataLinearR r]
-> ADFcnnMnist2Parameters Concrete r Float -> r
forall (target :: Target) r q.
((target :: Target) ~ (Concrete :: Target), GoodScalar r,
 Differentiable r, GoodScalar q, Differentiable q) =>
[MnistDataLinearR r] -> ADFcnnMnist2Parameters target r q -> r
MnistFcnnRanked2.afcnnMnistTest2 [MnistDataLinearR r]
testData (Concrete (XParams2 r Float)
-> ADFcnnMnist2Parameters Concrete r Float
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete (XParams2 r Float)
res)
              lenChunk :: Int
lenChunk = [MnistDataLinearR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataLinearR r]
chunk
          Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
widthHidden Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> Int -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: (Batch %d with %d points)"
                     TestName
prefix Int
k Int
lenChunk
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> r -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"%s: Training error:   %.2f%%"
                     TestName
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$
              TestName -> TestName -> r -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"%s: Validation error: %.2f%%"
                     TestName
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
          Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
res
    let runEpoch :: Int -> Concrete (XParams2 r Float)
                 -> IO (Concrete (XParams2 r Float))
        runEpoch Int
n Concrete (XParams2 r Float)
params | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
params
        runEpoch Int
n !Concrete (XParams2 r Float)
params = do
          Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
widthHidden Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
            Handle -> TestName -> Assertion
hPutStrLn Handle
stderr (TestName -> Assertion) -> TestName -> Assertion
forall a b. (a -> b) -> a -> b
$ TestName -> TestName -> Int -> TestName
forall r. PrintfType r => TestName -> r
printf TestName
"\n%s: [Epoch %d]" TestName
prefix Int
n
          let trainDataShuffled :: [MnistData r]
trainDataShuffled = StdGen -> [MnistData r] -> [MnistData r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) [MnistData r]
trainData
              chunks :: [(Int, [MnistDataLinearR r])]
chunks = Int
-> [(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
                       ([(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])])
-> [(Int, [MnistDataLinearR r])] -> [(Int, [MnistDataLinearR r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
Item [Int]
1 ..] ([[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])])
-> [[MnistDataLinearR r]] -> [(Int, [MnistDataLinearR r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataLinearR r] -> [[MnistDataLinearR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
batchSize
                       ([MnistDataLinearR r] -> [[MnistDataLinearR r]])
-> [MnistDataLinearR r] -> [[MnistDataLinearR r]]
forall a b. (a -> b) -> a -> b
$ (MnistData r -> MnistDataLinearR r)
-> [MnistData r] -> [MnistDataLinearR r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataLinearR r
forall r. PrimElt r => MnistData r -> MnistDataLinearR r
mkMnistDataLinearR [MnistData r]
trainDataShuffled
          res <- (Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
         (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
      (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
 -> (Int, [MnistDataLinearR r])
 -> IO
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
               (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
            (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> [(Int, [MnistDataLinearR r])]
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> (Int, [MnistDataLinearR r])
-> IO
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams2 r Float)
-> (Int, [MnistDataLinearR r]) -> IO (Concrete (XParams2 r Float))
runBatch Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
params [(Int, [MnistDataLinearR r])]
chunks
          runEpoch (succ n) res
    res <- runEpoch 1 targetInit
    let testErrorFinal =
          r
1 r -> r -> r
forall a. Num a => a -> a -> a
- [MnistDataLinearR r]
-> ADFcnnMnist2Parameters Concrete r Float -> r
forall (target :: Target) r q.
((target :: Target) ~ (Concrete :: Target), GoodScalar r,
 Differentiable r, GoodScalar q, Differentiable q) =>
[MnistDataLinearR r] -> ADFcnnMnist2Parameters target r q -> r
MnistFcnnRanked2.afcnnMnistTest2 [MnistDataLinearR r]
testData (Concrete (XParams2 r Float)
-> ADFcnnMnist2Parameters Concrete r Float
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (XParams2 r Float)
res)
    testErrorFinal @?~ expected

-- Specialise the polymorphic test driver 'mnistTestCase2VTO' at @r ~ Double@.
-- The type variable @r@ is the type of its final argument, the expected
-- validation error. This asks GHC to compile a monomorphic copy for the
-- Double-typed test cases. Callers at other scalar types (e.g. Float) are
-- not covered by this pragma and use the generic definition.
{-# SPECIALIZE mnistTestCase2VTO
  :: String
  -> Int -> Int -> Int -> Int -> Double -> Int -> Double
  -> TestTree #-}

tensorADOnceMnistTests2 :: TestTree
tensorADOnceMnistTests2 :: TestTree
tensorADOnceMnistTests2 = TestName -> [TestTree] -> TestTree
testGroup TestName
"Ranked2 Once MNIST tests"
  [ TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> Double -> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r,
 (ADTensorScalar r :: Type) ~ (r :: Type)) =>
TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> r -> TestTree
mnistTestCase2VTO TestName
"VTO2 1 epoch, 1 batch" Int
1 Int
1 Int
300 Int
100 Double
0.02 Int
5000
                       (Double
0.20779999999999998 :: Double)
  , TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> Float -> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r,
 (ADTensorScalar r :: Type) ~ (r :: Type)) =>
TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> r -> TestTree
mnistTestCase2VTO TestName
"VTO2 artificial 1 2 3 4 5" Int
1 Int
2 Int
3 Int
4 Double
5 Int
5000
                       (Float
0.9108 :: Float)
  , TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> Double -> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r,
 (ADTensorScalar r :: Type) ~ (r :: Type)) =>
TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> r -> TestTree
mnistTestCase2VTO TestName
"VTO2 artificial 5 4 3 2 1" Int
5 Int
4 Int
3 Int
2 Double
1 Int
5000
                       (Double
0.8129 :: Double)
  , TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> Float -> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
 AssertEqualUpToEpsilon r,
 (ADTensorScalar r :: Type) ~ (r :: Type)) =>
TestName
-> Int -> Int -> Int -> Int -> Double -> Int -> r -> TestTree
mnistTestCase2VTO TestName
"VTO2 1 epoch, 0 batch" Int
1 Int
0 Int
300 Int
100 Double
0.02 Int
5000
                       (Float
1 :: Float)
  , TestName -> (Int -> Property) -> TestTree
forall a. Testable a => TestName -> a -> TestTree
testProperty TestName
"VTO2 grad vs fwd" ((Int -> Property) -> TestTree) -> (Int -> Property) -> TestTree
forall a b. (a -> b) -> a -> b
$
    \Int
seed0 ->
    Gen Int -> (Int -> [Int]) -> (Int -> Property) -> Property
forall a prop.
(Show a, Testable prop) =>
Gen a -> (a -> [a]) -> (a -> prop) -> Property
forAllShrink ((Int, Int) -> Gen Int
chooseInt (Int
0, Int
600)) Int -> [Int]
forall a. Integral a => a -> [a]
shrinkIntegral ((Int -> Property) -> Property) -> (Int -> Property) -> Property
forall a b. (a -> b) -> a -> b
$ \Int
width1Hidden ->
    Gen Int -> (Int -> [Int]) -> (Int -> Property) -> Property
forall a prop.
(Show a, Testable prop) =>
Gen a -> (a -> [a]) -> (a -> prop) -> Property
forAllShrink ((Int, Int) -> Gen Int
chooseInt (Int
0, Int
200)) Int -> [Int]
forall a. Integral a => a -> [a]
shrinkIntegral ((Int -> Property) -> Property) -> (Int -> Property) -> Property
forall a b. (a -> b) -> a -> b
$ \Int
width1Hidden2 ->
    Gen Int -> (Int -> [Int]) -> (Int -> Property) -> Property
forall a prop.
(Show a, Testable prop) =>
Gen a -> (a -> [a]) -> (a -> prop) -> Property
forAllShrink ((Int, Int) -> Gen Int
chooseInt (Int
0, Int
5)) Int -> [Int]
forall a. Integral a => a -> [a]
shrinkIntegral ((Int -> Property) -> Property) -> (Int -> Property) -> Property
forall a b. (a -> b) -> a -> b
$ \Int
simp ->
    Gen Double -> (Double -> Property) -> Property
forall a prop.
(Show a, Testable prop) =>
Gen a -> (a -> prop) -> Property
forAll ((Double, Double) -> Gen Double
forall a. Random a => (a, a) -> Gen a
choose (Double
0.01, Double
1)) ((Double -> Property) -> Property)
-> (Double -> Property) -> Property
forall a b. (a -> b) -> a -> b
$ \Double
range ->
    Gen Double -> (Double -> Property) -> Property
forall a prop.
(Show a, Testable prop) =>
Gen a -> (a -> prop) -> Property
forAll ((Double, Double) -> Gen Double
forall a. Random a => (a, a) -> Gen a
choose (Double
0.01, Double
1)) ((Double -> Property) -> Property)
-> (Double -> Property) -> Property
forall a b. (a -> b) -> a -> b
$ \Double
range2 ->
    Gen Double -> (Double -> Property) -> Property
forall a prop.
(Show a, Testable prop) =>
Gen a -> (a -> prop) -> Property
forAll ((Double, Double) -> Gen Double
forall a. Random a => (a, a) -> Gen a
choose (Double
0.5, Double
1.5)) ((Double -> Property) -> Property)
-> (Double -> Property) -> Property
forall a b. (a -> b) -> a -> b
$ \Double
dt ->
    Gen Double -> (Double -> Property) -> Property
forall a prop.
(Show a, Testable prop) =>
Gen a -> (a -> prop) -> Property
forAll ((Double, Double) -> Gen Double
forall a. Random a => (a, a) -> Gen a
choose (Double
0, Double
1e-7)) ((Double -> Property) -> Property)
-> (Double -> Property) -> Property
forall a b. (a -> b) -> a -> b
$ \(Double
perturbation :: Double) ->
    Int
-> (forall (n :: Nat). KnownNat n => SNat n -> Property)
-> Property
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat (Int
1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
width1Hidden) ((forall (n :: Nat). KnownNat n => SNat n -> Property) -> Property)
-> (forall (n :: Nat). KnownNat n => SNat n -> Property)
-> Property
forall a b. (a -> b) -> a -> b
$ \(SNat @widthHidden) ->
    Int
-> (forall (n :: Nat). KnownNat n => SNat n -> Property)
-> Property
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat (Int
1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
width1Hidden2) ((forall (n :: Nat). KnownNat n => SNat n -> Property) -> Property)
-> (forall (n :: Nat). KnownNat n => SNat n -> Property)
-> Property
forall a b. (a -> b) -> a -> b
$ \(SNat @widthHidden2) ->
    let (Concrete (TKS ((':) @Nat SizeMnistGlyph ('[] @Nat)) Double)
glyph0, StdGen
seed2) = forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue @(Concrete (TKS '[SizeMnistGlyph] Double))
                                      Double
0.5 (Int -> StdGen
mkStdGen Int
seed0)
        (Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)
label0, StdGen
seed3) = forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue @(Concrete (TKS '[SizeMnistLabel] Double))
                                      Double
5 StdGen
seed2
        (Concrete (TKR 1 Double)
glyph, Concrete (TKR 1 Double)
label) = ( (Concrete (TKR2 0 (TKScalar Double))
 -> Concrete (TKR2 0 (TKScalar Double)))
-> Concrete (TKR2 (1 + 0) (TKScalar Double))
-> Concrete (TKR2 (1 + 0) (TKScalar Double))
forall (n :: Nat) (x :: TK) (x2 :: TK) (target :: Target).
(KnownNat n, KnownSTK x, KnownSTK x2, BaseTensor target) =>
(target (TKR2 n x) -> target (TKR2 n x2))
-> target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x2)
rmap1 (Double -> Concrete (TKR2 0 (TKScalar Double))
forall r (target :: Target).
(GoodScalar r, BaseTensor target) =>
r -> target (TKR 0 r)
rscalar Double
0.5 Concrete (TKR2 0 (TKScalar Double))
-> Concrete (TKR2 0 (TKScalar Double))
-> Concrete (TKR2 0 (TKScalar Double))
forall a. Num a => a -> a -> a
+) (Concrete (TKR2 (1 + 0) (TKScalar Double))
 -> Concrete (TKR2 (1 + 0) (TKScalar Double)))
-> Concrete (TKR2 (1 + 0) (TKScalar Double))
-> Concrete (TKR2 (1 + 0) (TKScalar Double))
forall a b. (a -> b) -> a -> b
$ Concrete (TKS ((':) @Nat 784 ('[] @Nat)) Double)
-> NoShape (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) Double))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape Concrete (TKS ((':) @Nat 784 ('[] @Nat)) Double)
Concrete (TKS ((':) @Nat SizeMnistGlyph ('[] @Nat)) Double)
glyph0
                         , (Concrete (TKR2 0 (TKScalar Double))
 -> Concrete (TKR2 0 (TKScalar Double)))
-> Concrete (TKR2 (1 + 0) (TKScalar Double))
-> Concrete (TKR2 (1 + 0) (TKScalar Double))
forall (n :: Nat) (x :: TK) (x2 :: TK) (target :: Target).
(KnownNat n, KnownSTK x, KnownSTK x2, BaseTensor target) =>
(target (TKR2 n x) -> target (TKR2 n x2))
-> target (TKR2 (1 + n) x) -> target (TKR2 (1 + n) x2)
rmap1 (Double -> Concrete (TKR2 0 (TKScalar Double))
forall r (target :: Target).
(GoodScalar r, BaseTensor target) =>
r -> target (TKR 0 r)
rscalar Double
5 Concrete (TKR2 0 (TKScalar Double))
-> Concrete (TKR2 0 (TKScalar Double))
-> Concrete (TKR2 0 (TKScalar Double))
forall a. Num a => a -> a -> a
+ ) (Concrete (TKR2 (1 + 0) (TKScalar Double))
 -> Concrete (TKR2 (1 + 0) (TKScalar Double)))
-> Concrete (TKR2 (1 + 0) (TKScalar Double))
-> Concrete (TKR2 (1 + 0) (TKScalar Double))
forall a b. (a -> b) -> a -> b
$ Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)
-> NoShape
     (Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)
label0 )
        ds :: Concrete (XParams2 Double Double)
        (Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
Concrete (XParams2 Double Double)
ds, StdGen
seed4) = (Concrete
   (X (ADFcnnMnist2ParametersShaped Concrete n n Double Double))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
-> (Concrete
      (X (ADFcnnMnist2ParametersShaped Concrete n n Double Double)),
    StdGen)
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))),
    StdGen)
forall a b c. (a -> b) -> (a, c) -> (b, c)
forall (p :: Type -> Type -> Type) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar Double))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double)))
        (TKProduct
           (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Double))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat)))
           (TKScalar Double))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar Double))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double)))
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Double))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat)))
                 (TKScalar Double))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
Concrete
  (X (ADFcnnMnist2ParametersShaped Concrete n n Double Double))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape ((Concrete
    (X (ADFcnnMnist2ParametersShaped Concrete n n Double Double)),
  StdGen)
 -> (Concrete
       (TKProduct
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
             (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
          (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))),
     StdGen))
-> (Concrete
      (X (ADFcnnMnist2ParametersShaped Concrete n n Double Double)),
    StdGen)
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))),
    StdGen)
forall a b. (a -> b) -> a -> b
$
          forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue
            @(Concrete (X (MnistFcnnRanked2.ADFcnnMnist2ParametersShaped
                             Concrete widthHidden widthHidden2 Double Double)))
            Double
range StdGen
seed3
        (Concrete (XParams2 Double Double)
targetInit, AstArtifactRev
  (TKProduct
     (XParams2 Double Double) (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
artRaw) =
          forall r q.
(GoodScalar r, Differentiable r, GoodScalar q, Differentiable q) =>
Proxy @Type q
-> IncomingCotangentHandling
-> Double
-> StdGen
-> Int
-> Int
-> (Concrete (XParams2 r q),
    AstArtifactRev
      (TKProduct
         (XParams2 r q)
         (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKScalar r))
MnistFcnnRanked2.mnistTrainBench2VTOGradient
            @Double (forall t. Proxy @Type t
forall {k} (t :: k). Proxy @k t
Proxy @Double) IncomingCotangentHandling
UseIncomingCotangent
            Double
range2 StdGen
seed4 (Int
1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
width1Hidden) (Int
1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
width1Hidden2)
        art :: AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
art = (AstArtifactRev
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
      (TKProduct (TKR 1 Double) (TKR 1 Double)))
   (TKScalar Double)
 -> AstArtifactRev
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR 1 Double) (TKR 1 Double)))
      (TKScalar Double))
-> AstArtifactRev
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR 1 Double) (TKR 1 Double)))
     (TKScalar Double)
-> [AstArtifactRev
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR 1 Double) (TKR 1 Double)))
      (TKScalar Double)]
forall a. (a -> a) -> a -> [a]
iterate AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
-> AstArtifactRev
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR 1 Double) (TKR 1 Double)))
     (TKScalar Double)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifactGradient AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
AstArtifactRev
  (TKProduct
     (XParams2 Double Double) (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
artRaw [AstArtifactRev
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
      (TKProduct (TKR 1 Double) (TKR 1 Double)))
   (TKScalar Double)]
-> Int
-> AstArtifactRev
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR 1 Double) (TKR 1 Double)))
     (TKScalar Double)
forall a. HasCallStack => [a] -> Int -> a
!! Int
simp
        stk :: SingletonTK (XParams2 Double Double)
stk = forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams2 Double Double)
        ftk :: FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
ftk = forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
stk Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
Concrete (XParams2 Double Double)
targetInit
        parametersAndInput :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
parametersAndInput = Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete (TKProduct (TKR 1 Double) (TKR 1 Double))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR 1 Double) (TKR 1 Double)))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
Concrete (XParams2 Double Double)
targetInit (Concrete (TKR 1 Double)
-> Concrete (TKR 1 Double)
-> Concrete (TKProduct (TKR 1 Double) (TKR 1 Double))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair Concrete (TKR 1 Double)
glyph Concrete (TKR 1 Double)
label)
        (Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
_gradient0, Concrete (TKScalar Double)
value0) = (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
      (TKProduct (TKR 1 Double) (TKR 1 Double)))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR 1 Double) (TKR 1 Double))),
    Concrete (TKScalar Double))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))),
    Concrete (TKScalar Double))
forall a b c. (a -> b) -> (a, c) -> (b, c)
forall (p :: Type -> Type -> Type) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
forall (x :: TK) (z :: TK). Concrete (TKProduct x z) -> Concrete x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
             (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
          (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
       (TKProduct (TKR 1 Double) (TKR 1 Double))),
  Concrete (TKScalar Double))
 -> (Concrete
       (TKProduct
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
             (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
          (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))),
     Concrete (TKScalar Double)))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR 1 Double) (TKR 1 Double))),
    Concrete (TKScalar Double))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))),
    Concrete (TKScalar Double))
forall a b. (a -> b) -> a -> b
$
          AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR 1 Double) (TKR 1 Double)))
-> Maybe (Concrete (ADTensorKind (TKScalar Double)))
-> (Concrete
      (ADTensorKind
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
                  (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
            (TKProduct (TKR 1 Double) (TKR 1 Double)))),
    Concrete (TKScalar Double))
forall (x :: TK) (z :: TK).
AstArtifactRev x z
-> Concrete x
-> Maybe (Concrete (ADTensorKind z))
-> (Concrete (ADTensorKind x), Concrete z)
revInterpretArtifact AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
art Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
parametersAndInput Maybe (Concrete (ADTensorKind (TKScalar Double)))
Maybe (Concrete (TKScalar Double))
forall a. Maybe a
Nothing
        (Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
gradient1, Concrete (TKScalar Double)
value1) = (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
      (TKProduct (TKR 1 Double) (TKR 1 Double)))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR 1 Double) (TKR 1 Double))),
    Concrete (TKScalar Double))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))),
    Concrete (TKScalar Double))
forall a b c. (a -> b) -> (a, c) -> (b, c)
forall (p :: Type -> Type -> Type) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
forall (x :: TK) (z :: TK). Concrete (TKProduct x z) -> Concrete x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
             (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
          (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
       (TKProduct (TKR 1 Double) (TKR 1 Double))),
  Concrete (TKScalar Double))
 -> (Concrete
       (TKProduct
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
             (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
          (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))),
     Concrete (TKScalar Double)))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR 1 Double) (TKR 1 Double))),
    Concrete (TKScalar Double))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))),
    Concrete (TKScalar Double))
forall a b. (a -> b) -> a -> b
$
          AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR 1 Double) (TKR 1 Double)))
-> Maybe (Concrete (ADTensorKind (TKScalar Double)))
-> (Concrete
      (ADTensorKind
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
                  (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
            (TKProduct (TKR 1 Double) (TKR 1 Double)))),
    Concrete (TKScalar Double))
forall (x :: TK) (z :: TK).
AstArtifactRev x z
-> Concrete x
-> Maybe (Concrete (ADTensorKind z))
-> (Concrete (ADTensorKind x), Concrete z)
revInterpretArtifact AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
art Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
parametersAndInput (Concrete (ADTensorKind (TKScalar Double))
-> Maybe (Concrete (ADTensorKind (TKScalar Double)))
forall a. a -> Maybe a
Just (Concrete (ADTensorKind (TKScalar Double))
 -> Maybe (Concrete (ADTensorKind (TKScalar Double))))
-> Concrete (ADTensorKind (TKScalar Double))
-> Maybe (Concrete (ADTensorKind (TKScalar Double)))
forall a b. (a -> b) -> a -> b
$ Double -> Concrete (TKScalar Double)
forall r (target :: Target).
(GoodScalar r, BaseTensor target) =>
r -> target (TKScalar r)
kconcrete Double
dt)
        f :: ADVal Concrete (XParams2 Double Double)
          -> ADVal Concrete (TKScalar Double)
        f :: ADVal Concrete (XParams2 Double Double)
-> ADVal Concrete (TKScalar Double)
f ADVal Concrete (XParams2 Double Double)
adinputs =
          (ADVal Concrete (TKR 1 Double), ADVal Concrete (TKR 1 Double))
-> ADFcnnMnist2Parameters (ADVal Concrete) Double Double
-> ADVal Concrete (TKScalar Double)
forall (target :: Target) r q.
(ADReady target, GoodScalar r, Differentiable r, GoodScalar q,
 Differentiable q) =>
(target (TKR 1 r), target (TKR 1 r))
-> ADFcnnMnist2Parameters target r q -> target (TKScalar r)
MnistFcnnRanked2.afcnnMnistLoss2
            (PrimalOf (ADVal Concrete) (TKR 1 Double)
-> ADVal Concrete (TKR 1 Double)
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownNat n, KnownSTK x) =>
PrimalOf target (TKR2 n x) -> target (TKR2 n x)
rfromPrimal Concrete (TKR 1 Double)
PrimalOf (ADVal Concrete) (TKR 1 Double)
glyph, PrimalOf (ADVal Concrete) (TKR 1 Double)
-> ADVal Concrete (TKR 1 Double)
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownNat n, KnownSTK x) =>
PrimalOf target (TKR2 n x) -> target (TKR2 n x)
rfromPrimal Concrete (TKR 1 Double)
PrimalOf (ADVal Concrete) (TKR 1 Double)
label) (ADVal
  Concrete
  (X (ADFcnnMnist2Parameters (ADVal Concrete) Double Double))
-> ADFcnnMnist2Parameters (ADVal Concrete) Double Double
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget ADVal
  Concrete
  (X (ADFcnnMnist2Parameters (ADVal Concrete) Double Double))
ADVal Concrete (XParams2 Double Double)
adinputs)
        (Concrete (ADTensorKind (TKScalar Double))
derivative2, Concrete (TKScalar Double)
value2) = (ADVal
   Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
      (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
 -> ADVal Concrete (TKScalar Double))
-> DValue
     (ADVal
        Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
-> DValue
     (ADVal
        Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
-> (Concrete (ADTensorKind (TKScalar Double)),
    Concrete (TKScalar Double))
forall src (ztgt :: TK) tgt.
((X src :: TK) ~ (X (DValue src) :: TK), KnownSTK (X src),
 AdaptableTarget (ADVal Concrete) src,
 AdaptableTarget Concrete (DValue src),
 (tgt :: Type) ~ (ADVal Concrete ztgt :: Type)) =>
(src -> tgt)
-> DValue src
-> DValue src
-> (Concrete (ADTensorKind ztgt), Concrete ztgt)
cfwdBoth ADVal
  Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> ADVal Concrete (TKScalar Double)
ADVal Concrete (XParams2 Double Double)
-> ADVal Concrete (TKScalar Double)
f Concrete (XParams2 Double Double)
DValue
  (ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
targetInit Concrete (XParams2 Double Double)
DValue
  (ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
ds
--        goodDt :: forall r. GoodScalar r => r
--        goodDt = ifDifferentiable @r (realToFrac dt) 0
--        targetDt :: Concrete (XParams2 Double Double)
--        targetDt = replTarget goodDt ftk
        goodPerturbation :: forall r. GoodScalar r => r
        goodPerturbation :: forall r. GoodScalar r => r
goodPerturbation = forall r a. IfDifferentiable r => (Differentiable r => a) -> a -> a
ifDifferentiable @r (Double -> r
forall a b. (Real a, Fractional b) => a -> b
realToFrac Double
perturbation) r
0
        targetPerturbed :: Concrete (XParams2 Double Double)
        targetPerturbed :: Concrete (XParams2 Double Double)
targetPerturbed = (forall r. GoodScalar r => r)
-> FullShapeTK
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
forall (y :: TK).
(forall r. GoodScalar r => r) -> FullShapeTK y -> Concrete y
forall (target :: Target) (y :: TK).
BaseTensor target =>
(forall r. GoodScalar r => r) -> FullShapeTK y -> target y
treplTarget r
forall r. GoodScalar r => r
goodPerturbation FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
ftk
        targetInitPerturbed :: Concrete (XParams2 Double Double)
        targetInitPerturbed :: Concrete (XParams2 Double Double)
targetInitPerturbed = SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
forall (y :: TK).
SingletonTK y -> Concrete y -> Concrete y -> Concrete y
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> target y -> target y
taddTarget SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
stk Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
Concrete (XParams2 Double Double)
targetInit Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
Concrete (XParams2 Double Double)
targetPerturbed
        (Concrete (ADTensorKind (TKScalar Double))
derivative3, Concrete (TKScalar Double)
value3) = (ADVal
   Concrete
   (TKProduct
      (TKProduct
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
      (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
 -> ADVal Concrete (TKScalar Double))
-> DValue
     (ADVal
        Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
-> DValue
     (ADVal
        Concrete
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
-> (Concrete (ADTensorKind (TKScalar Double)),
    Concrete (TKScalar Double))
forall src (ztgt :: TK) tgt.
((X src :: TK) ~ (X (DValue src) :: TK), KnownSTK (X src),
 AdaptableTarget (ADVal Concrete) src,
 AdaptableTarget Concrete (DValue src),
 (tgt :: Type) ~ (ADVal Concrete ztgt :: Type)) =>
(src -> tgt)
-> DValue src
-> DValue src
-> (Concrete (ADTensorKind ztgt), Concrete ztgt)
cfwdBoth ADVal
  Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> ADVal Concrete (TKScalar Double)
ADVal Concrete (XParams2 Double Double)
-> ADVal Concrete (TKScalar Double)
f Concrete (XParams2 Double Double)
DValue
  (ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
targetInit Concrete (XParams2 Double Double)
DValue
  (ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
targetPerturbed
        value4 :: Concrete (TKScalar Double)
        value4 :: Concrete (TKScalar Double)
value4 = (Concrete (TKR 1 Double), Concrete (TKR 1 Double))
-> ADFcnnMnist2Parameters Concrete Double Double
-> Concrete (TKScalar Double)
forall (target :: Target) r q.
(ADReady target, GoodScalar r, Differentiable r, GoodScalar q,
 Differentiable q) =>
(target (TKR 1 r), target (TKR 1 r))
-> ADFcnnMnist2Parameters target r q -> target (TKScalar r)
MnistFcnnRanked2.afcnnMnistLoss2
                   (PrimalOf Concrete (TKR 1 Double) -> Concrete (TKR 1 Double)
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownNat n, KnownSTK x) =>
PrimalOf target (TKR2 n x) -> target (TKR2 n x)
rfromPrimal Concrete (TKR 1 Double)
PrimalOf Concrete (TKR 1 Double)
glyph, PrimalOf Concrete (TKR 1 Double) -> Concrete (TKR 1 Double)
forall (target :: Target) (n :: Nat) (x :: TK).
(BaseTensor target, KnownNat n, KnownSTK x) =>
PrimalOf target (TKR2 n x) -> target (TKR2 n x)
rfromPrimal Concrete (TKR 1 Double)
PrimalOf Concrete (TKR 1 Double)
label)
                   (Concrete (XParams2 Double Double)
-> ADFcnnMnist2Parameters Concrete Double Double
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete (XParams2 Double Double)
targetInitPerturbed)
    in
      [Property] -> Property
forall prop. Testable prop => [prop] -> Property
conjoin
        [ TestName -> Bool -> Property
forall prop. Testable prop => TestName -> prop -> Property
counterexample
            (TestName
"Objective function value from grad and jvp matches: "
             TestName -> TestName -> TestName
forall a. [a] -> [a] -> [a]
++ (Concrete (TKScalar Double), Concrete (TKScalar Double),
 Concrete (TKScalar Double))
-> TestName
forall a. Show a => a -> TestName
show (Concrete (TKScalar Double)
value1, Concrete (TKScalar Double)
value2, Concrete (TKScalar Double)
value1 Concrete (TKScalar Double)
-> Concrete (TKScalar Double) -> Concrete (TKScalar Double)
forall a. Num a => a -> a -> a
- Concrete (TKScalar Double)
value2))
            (Concrete (TKScalar Double) -> Concrete (TKScalar Double)
forall a. Num a => a -> a
abs (Concrete (TKScalar Double)
value1 Concrete (TKScalar Double)
-> Concrete (TKScalar Double) -> Concrete (TKScalar Double)
forall a. Num a => a -> a -> a
- Concrete (TKScalar Double)
value2) Concrete (TKScalar Double) -> Concrete (TKScalar Double) -> Bool
forall a. Ord a => a -> a -> Bool
< Concrete (TKScalar Double)
1e-10)
        , TestName -> Bool -> Property
forall prop. Testable prop => TestName -> prop -> Property
counterexample
            (TestName
"Gradient and derivative agrees: "
             TestName -> TestName -> TestName
forall a. [a] -> [a] -> [a]
++ (Double, Concrete (TKScalar Double), Concrete (TKScalar Double),
 Concrete (TKScalar Double))
-> TestName
forall a. Show a => a -> TestName
show ( Double
dt, Concrete (ADTensorKind (TKScalar Double))
Concrete (TKScalar Double)
derivative2, FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete (TKScalar Double)
forall (y :: TK).
FullShapeTK y
-> Concrete y -> Concrete y -> Concrete (TKScalar Double)
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y -> target y -> target (TKScalar Double)
tdot0Target FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
ftk Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
gradient1 Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
Concrete (XParams2 Double Double)
ds
                     , FullShapeTK (TKScalar Double)
-> Concrete (TKScalar Double)
-> Concrete (TKScalar Double)
-> Concrete (TKScalar Double)
forall (y :: TK).
FullShapeTK y
-> Concrete y -> Concrete y -> Concrete (TKScalar Double)
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y -> target y -> target (TKScalar Double)
tdot0Target FullShapeTK (TKScalar Double)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar (Double -> Concrete (TKScalar Double)
forall r (target :: Target).
(GoodScalar r, BaseTensor target) =>
r -> target (TKScalar r)
kconcrete Double
dt) Concrete (ADTensorKind (TKScalar Double))
Concrete (TKScalar Double)
derivative2
                       Concrete (TKScalar Double)
-> Concrete (TKScalar Double) -> Concrete (TKScalar Double)
forall a. Num a => a -> a -> a
- FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete (TKScalar Double)
forall (y :: TK).
FullShapeTK y
-> Concrete y -> Concrete y -> Concrete (TKScalar Double)
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y -> target y -> target (TKScalar Double)
tdot0Target FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
ftk Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
gradient1 Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
Concrete (XParams2 Double Double)
ds ))
            (Concrete (TKScalar Double) -> Concrete (TKScalar Double)
forall a. Num a => a -> a
abs (FullShapeTK (TKScalar Double)
-> Concrete (TKScalar Double)
-> Concrete (TKScalar Double)
-> Concrete (TKScalar Double)
forall (y :: TK).
FullShapeTK y
-> Concrete y -> Concrete y -> Concrete (TKScalar Double)
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y -> target y -> target (TKScalar Double)
tdot0Target FullShapeTK (TKScalar Double)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar (Double -> Concrete (TKScalar Double)
forall r (target :: Target).
(GoodScalar r, BaseTensor target) =>
r -> target (TKScalar r)
kconcrete Double
dt) Concrete (ADTensorKind (TKScalar Double))
Concrete (TKScalar Double)
derivative2
                  Concrete (TKScalar Double)
-> Concrete (TKScalar Double) -> Concrete (TKScalar Double)
forall a. Num a => a -> a -> a
- FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete (TKScalar Double)
forall (y :: TK).
FullShapeTK y
-> Concrete y -> Concrete y -> Concrete (TKScalar Double)
forall (target :: Target) (y :: TK).
BaseTensor target =>
FullShapeTK y -> target y -> target y -> target (TKScalar Double)
tdot0Target FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
ftk Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
gradient1 Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
Concrete (XParams2 Double Double)
ds) Concrete (TKScalar Double) -> Concrete (TKScalar Double) -> Bool
forall a. Ord a => a -> a -> Bool
< Concrete (TKScalar Double)
1e-10)
--        , counterexample  -- this is implied by the other clauses
--            "Gradient is a linear function"
--            (gradient1 === tmultTarget stk targetDt gradient0)
        , TestName -> Property -> Property
forall prop. Testable prop => TestName -> prop -> Property
counterexample
            TestName
"Objective function value unaffected by incoming cotangent"
            (Concrete (TKScalar Double)
value0 Concrete (TKScalar Double)
-> Concrete (TKScalar Double) -> Property
forall a. (Eq a, Show a) => a -> a -> Property
=== Concrete (TKScalar Double)
value1)
        , TestName -> Property -> Property
forall prop. Testable prop => TestName -> prop -> Property
counterexample
            TestName
"Objective function value unaffected by derivative perturbation"
            (Concrete (TKScalar Double)
value2 Concrete (TKScalar Double)
-> Concrete (TKScalar Double) -> Property
forall a. (Eq a, Show a) => a -> a -> Property
=== Concrete (TKScalar Double)
value3)
        , TestName -> Bool -> Property
forall prop. Testable prop => TestName -> prop -> Property
counterexample
            (TestName
"Derivative approximates the perturbation of value: "
             TestName -> TestName -> TestName
forall a. [a] -> [a] -> [a]
++ (Concrete (TKScalar Double), Concrete (TKScalar Double),
 Concrete (TKScalar Double), Concrete (TKScalar Double))
-> TestName
forall a. Show a => a -> TestName
show ( Concrete (TKScalar Double)
value2, Concrete (ADTensorKind (TKScalar Double))
Concrete (TKScalar Double)
derivative3, Concrete (TKScalar Double)
value4
                     , (Concrete (TKScalar Double)
value3 Concrete (TKScalar Double)
-> Concrete (TKScalar Double) -> Concrete (TKScalar Double)
forall a. Num a => a -> a -> a
+ Concrete (ADTensorKind (TKScalar Double))
Concrete (TKScalar Double)
derivative3) Concrete (TKScalar Double)
-> Concrete (TKScalar Double) -> Concrete (TKScalar Double)
forall a. Num a => a -> a -> a
- Concrete (TKScalar Double)
value4) )
            (Concrete (TKScalar Double) -> Concrete (TKScalar Double)
forall a. Num a => a -> a
abs ((Concrete (TKScalar Double)
value3 Concrete (TKScalar Double)
-> Concrete (TKScalar Double) -> Concrete (TKScalar Double)
forall a. Num a => a -> a -> a
+ Concrete (ADTensorKind (TKScalar Double))
Concrete (TKScalar Double)
derivative3) Concrete (TKScalar Double)
-> Concrete (TKScalar Double) -> Concrete (TKScalar Double)
forall a. Num a => a -> a -> a
- Concrete (TKScalar Double)
value4) Concrete (TKScalar Double) -> Concrete (TKScalar Double) -> Bool
forall a. Ord a => a -> a -> Bool
< Concrete (TKScalar Double)
1e-6)
        ]
  ]