{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_GHC -Wno-missing-export-lists #-}
-- | A set of benchmarks using fully connected MNIST neural networks.
module BenchMnistTools where

import Prelude

import Control.Arrow ((***))
import Control.DeepSeq (NFData (..))
import Criterion.Main
import Data.Default qualified as Default
import Data.Proxy (Proxy (Proxy))
import GHC.Exts (WithDict)
import GHC.TypeLits (KnownNat)
import System.Random
import Test.Inspection
import Type.Reflection (Typeable)

import HordeAd
import HordeAd.Core.Adaptor
import HordeAd.Core.OpsConcrete ()
import HordeAd.External.OptimizerTools

import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Shaped.Shape

import MnistData
import MnistFcnnRanked1 qualified
import MnistFcnnRanked2 (XParams2)
import MnistFcnnRanked2 qualified

-- * Using lists of vectors, which is rank 1

-- | Tensor kind of the rank-1 fully connected MNIST parameter record
-- ('MnistFcnnRanked1.ADFcnnMnist1Parameters') at the 'Concrete' target.
--
-- The first two arguments are forwarded unchanged as the record's
-- @widthHidden@ and @widthHidden2@ indices. The third argument is the
-- scalar type.
type XParams width1 width2 r =
  X (MnistFcnnRanked1.ADFcnnMnist1Parameters Concrete width1 width2 r)

-- POPL differentiation, straight via the ADVal instance of RankedTensor,
-- which side-steps vectorization.
mnistTrainBench1VTA
  :: forall r. r ~ Double
  => String
  -> Int -> Int -> Double -> Int -> [MnistDataLinearR r]
  -> Benchmark
mnistTrainBench1VTA :: forall r.
((r :: Type) ~ (Double :: Type)) =>
String
-> Int -> Int -> Double -> Int -> [MnistDataLinearR r] -> Benchmark
mnistTrainBench1VTA String
prefix Int
widthHiddenInt Int
widthHidden2Int
                    Double
gamma Int
batchSize [MnistDataLinearR r]
xs =
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHiddenInt ((forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
 -> Benchmark)
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$ \(SNat n
widthHiddenSNat :: SNat widthHidden) ->
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHidden2Int ((forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
 -> Benchmark)
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$ \(SNat n
widthHidden2SNat :: SNat widthHidden2) ->
  SingletonTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
-> (KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)) =>
    Benchmark)
-> Benchmark
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK
    (SingletonTK (TKS ((':) @Nat 784 ('[] @Nat)) Double)
-> SNat n
-> SingletonTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
forall (t :: TK) (n :: Nat).
SingletonTK t -> SNat n -> SingletonTK (Tups n t)
stkOfListR (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(TKS '[SizeMnistGlyph] r)) (forall (n :: Nat). KnownNat n => SNat n
SNat @widthHidden)) ((KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)) =>
  Benchmark)
 -> Benchmark)
-> (KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)) =>
    Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$
  SingletonTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
-> (KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
    Benchmark)
-> Benchmark
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK
    (SingletonTK (TKS ((':) @Nat n ('[] @Nat)) Float)
-> SNat n
-> SingletonTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
forall (t :: TK) (n :: Nat).
SingletonTK t -> SNat n -> SingletonTK (Tups n t)
stkOfListR (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(TKS '[widthHidden] Float)) (forall (n :: Nat). KnownNat n => SNat n
SNat @widthHidden2)) ((KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
  Benchmark)
 -> Benchmark)
-> (KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
    Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$
  let valsInit :: MnistFcnnRanked1.ADFcnnMnist1Parameters
                    Concrete widthHidden widthHidden2 r
      valsInit :: ADFcnnMnist1Parameters Concrete n n r
valsInit = (ADFcnnMnist1Parameters Concrete n n r, StdGen)
-> ADFcnnMnist1Parameters Concrete n n r
forall a b. (a, b) -> a
fst ((ADFcnnMnist1Parameters Concrete n n r, StdGen)
 -> ADFcnnMnist1Parameters Concrete n n r)
-> (ADFcnnMnist1Parameters Concrete n n r, StdGen)
-> ADFcnnMnist1Parameters Concrete n n r
forall a b. (a -> b) -> a -> b
$ Double
-> StdGen
-> (((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
      Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
     (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
      Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
     (ListR
        SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
      Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
    StdGen)
forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue Double
1 (Int -> StdGen
mkStdGen Int
44)
      targetInit :: Concrete (XParams widthHidden widthHidden2 r)
      targetInit :: Concrete (XParams n n r)
targetInit = forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
  Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
 (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
  Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
 (ListR
    SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
  Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
ADFcnnMnist1Parameters Concrete n n r
valsInit
  in do
    let f :: MnistDataLinearR Double
          -> ADVal Concrete (XParams widthHidden widthHidden2 Double)
          -> ADVal Concrete (TKScalar Double)
        f :: MnistDataLinearR Double
-> ADVal Concrete (XParams n n Double)
-> ADVal Concrete (TKScalar Double)
f (Ranked 1 Double
glyph, Ranked 1 Double
label) ADVal Concrete (XParams n n Double)
adinputs =
          SNat n
-> SNat n
-> (ADVal Concrete (TKR 1 Double), ADVal Concrete (TKR 1 Double))
-> ADFcnnMnist1Parameters (ADVal Concrete) n n Double
-> ADVal Concrete (TKScalar Double)
forall (target :: Target) r (widthHidden :: Nat)
       (widthHidden2 :: Nat).
(ADReady target, GoodScalar r, Differentiable r) =>
SNat widthHidden
-> SNat widthHidden2
-> (target (TKR 1 r), target (TKR 1 r))
-> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
-> target (TKScalar r)
MnistFcnnRanked1.afcnnMnistLoss1
            SNat n
widthHiddenSNat SNat n
widthHidden2SNat
            (Ranked 1 Double -> ADVal Concrete (TKR 1 Double)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 Double
glyph, Ranked 1 Double -> ADVal Concrete (TKR 1 Double)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 Double
label) (ADVal
  Concrete
  (X ((ListR
         n (ADVal Concrete (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
       ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
      (ListR n (ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
      (ListR
         SizeMnistLabel
         (ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
       ADVal
         Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
-> ((ListR
       n (ADVal Concrete (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
     ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
    (ListR n (ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
     ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
    (ListR
       SizeMnistLabel
       (ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
     ADVal Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget ADVal
  Concrete
  (X ((ListR
         n (ADVal Concrete (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
       ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
      (ListR n (ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
       ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
      (ListR
         SizeMnistLabel
         (ADVal Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
       ADVal
         Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
ADVal Concrete (XParams n n Double)
adinputs)
        chunk :: [MnistDataLinearR r]
chunk = Int -> [MnistDataLinearR r] -> [MnistDataLinearR r]
forall a. Int -> [a] -> [a]
take Int
batchSize [MnistDataLinearR r]
xs
        gradf :: [MnistDataLinearR Double]
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
gradf [MnistDataLinearR Double]
c = (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
            (TKS ((':) @Nat n ('[] @Nat)) Double))
         (TKProduct
            (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
            (TKS ((':) @Nat n ('[] @Nat)) Double)))
      (TKProduct
         (TKProduct
            (TKS ((':) @Nat n ('[] @Nat)) Double)
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) Double)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) Double)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) Double)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) Double)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) Double)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) Double)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
         (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
 Concrete (TKScalar Double))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
forall a b. (a, b) -> a
fst ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
             (TKS ((':) @Nat n ('[] @Nat)) Double))
          (TKProduct
             (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
             (TKS ((':) @Nat n ('[] @Nat)) Double)))
       (TKProduct
          (TKProduct
             (TKS ((':) @Nat n ('[] @Nat)) Double)
             (TKProduct
                (TKS ((':) @Nat n ('[] @Nat)) Double)
                (TKProduct
                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                   (TKProduct
                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                      (TKProduct
                         (TKS ((':) @Nat n ('[] @Nat)) Double)
                         (TKProduct
                            (TKS ((':) @Nat n ('[] @Nat)) Double)
                            (TKProduct
                               (TKS ((':) @Nat n ('[] @Nat)) Double)
                               (TKProduct
                                  (TKS ((':) @Nat n ('[] @Nat)) Double)
                                  (TKProduct
                                     (TKS ((':) @Nat n ('[] @Nat)) Double)
                                     (TKProduct
                                        (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
          (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
  Concrete (TKScalar Double))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
               (TKS ((':) @Nat n ('[] @Nat)) Double))
            (TKProduct
               (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
               (TKS ((':) @Nat n ('[] @Nat)) Double)))
         (TKProduct
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) Double)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) Double)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) Double)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) Double)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) Double)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) Double)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
            (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
               (TKS ((':) @Nat n ('[] @Nat)) Double))
            (TKProduct
               (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
               (TKS ((':) @Nat n ('[] @Nat)) Double)))
         (TKProduct
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) Double)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) Double)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) Double)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) Double)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) Double)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) Double)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
            (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
    Concrete (TKScalar Double))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
forall a b. (a -> b) -> a -> b
$ SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
           (TKS ((':) @Nat n ('[] @Nat)) Double))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) Double)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) Double)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> Double
-> (MnistDataLinearR Double
    -> ADVal
         Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
                  (TKS ((':) @Nat n ('[] @Nat)) Double))
               (TKProduct
                  (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                  (TKS ((':) @Nat n ('[] @Nat)) Double)))
            (TKProduct
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) Double)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) Double)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) Double)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) Double)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) Double)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                                          (TKProduct
                                             (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
               (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
    -> ADVal Concrete (TKScalar Double))
-> [MnistDataLinearR Double]
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
               (TKS ((':) @Nat n ('[] @Nat)) Double))
            (TKProduct
               (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
               (TKS ((':) @Nat n ('[] @Nat)) Double)))
         (TKProduct
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) Double)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) Double)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) Double)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) Double)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) Double)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) Double)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
            (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
    Concrete (TKScalar Double))
forall a (x :: TK) (z :: TK).
SingletonTK x
-> Double
-> (a -> ADVal Concrete x -> ADVal Concrete z)
-> [a]
-> Concrete x
-> (Concrete x, Concrete z)
sgdSTK SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
           (TKS ((':) @Nat n ('[] @Nat)) Double))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) Double)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) Double)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Double
gamma MnistDataLinearR Double
-> ADVal
     Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> ADVal Concrete (TKScalar Double)
MnistDataLinearR Double
-> ADVal Concrete (XParams n n Double)
-> ADVal Concrete (TKScalar Double)
f [MnistDataLinearR Double]
c Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
           (TKS ((':) @Nat n ('[] @Nat)) Double))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) Double)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) Double)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
Concrete (XParams n n r)
targetInit
        name :: String
name =
          String
prefix
          String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords
               [ String
"v" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show (SingletonTK (XParams n n r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
                              (SingletonTK (XParams n n r) -> Int)
-> SingletonTK (XParams n n r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams widthHidden widthHidden2 r))
               , String
"m0" String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
" =" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
           (TKS ((':) @Nat n ('[] @Nat)) Double))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) Double)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) Double)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
           (TKS ((':) @Nat n ('[] @Nat)) Double))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) Double)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) Double)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
           (TKS ((':) @Nat n ('[] @Nat)) Double))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) Double)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) Double)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
Concrete (XParams n n r)
targetInit) ]
    String -> Benchmarkable -> Benchmark
bench String
name (Benchmarkable -> Benchmark) -> Benchmarkable -> Benchmark
forall a b. (a -> b) -> a -> b
$ ([MnistDataLinearR Double]
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
               (TKS ((':) @Nat n ('[] @Nat)) Double))
            (TKProduct
               (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
               (TKS ((':) @Nat n ('[] @Nat)) Double)))
         (TKProduct
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) Double)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) Double)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) Double)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) Double)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) Double)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) Double)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
            (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
-> [MnistDataLinearR Double] -> Benchmarkable
forall b a. NFData b => (a -> b) -> a -> Benchmarkable
nf [MnistDataLinearR Double]
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
gradf [MnistDataLinearR r]
[MnistDataLinearR Double]
chunk

-- | Benchmark the rank-1 MNIST test function
-- ('MnistFcnnRanked1.afcnnMnistTest1') on the first @batchSize@ samples
-- of @xs@. The network parameters are drawn once via 'randomValue' with
-- the fixed generator @mkStdGen 44@.
--
-- The @Double@ argument is ignored. Presumably it exists so the signature
-- lines up with 'mnistTrainBench1VTA' (TODO confirm).
--
-- The benchmark name is @"test " ++ prefix@ followed by two fields:
-- @v<n>@, where @n@ is 'widthSTK' of the parameter singleton, and
-- @m0 =<m>@, where @m@ is 'tsize' of the initial parameter tensor.
mnistTestBench1VTA
  :: forall r. r ~ Double
  => String
  -> Int -> Int -> Double -> Int -> [MnistDataLinearR r]
  -> Benchmark
mnistTestBench1VTA prefix widthHiddenInt widthHidden2Int
                   _gamma batchSize xs =
  withSNat widthHiddenInt $ \(widthHiddenSNat :: SNat widthHidden) ->
  withSNat widthHidden2Int $ \(widthHidden2SNat :: SNat widthHidden2) ->
  -- Make the singletons for both list-of-tensor hidden layers known, so
  -- that 'knownSTK' resolves at the full parameter type further down.
  withKnownSTK (stkOfListR (knownSTK @(TKS '[SizeMnistGlyph] r))
                           (SNat @widthHidden)) $
  withKnownSTK (stkOfListR (knownSTK @(TKS '[widthHidden] Float))
                           (SNat @widthHidden2)) $
  let initParams :: MnistFcnnRanked1.ADFcnnMnist1Parameters
                      Concrete widthHidden widthHidden2 r
      initParams = fst (randomValue 1 (mkStdGen 44))
      initTarget :: Concrete (XParams widthHidden widthHidden2 r)
      initTarget = toTarget @Concrete initParams
      -- Run the test function of the fixed network on a batch of samples.
      testOn :: [MnistDataLinearR r] -> r
      testOn batch =
        MnistFcnnRanked1.afcnnMnistTest1
          widthHiddenSNat widthHidden2SNat batch initParams
      stkWidth = widthSTK (knownSTK @(XParams widthHidden widthHidden2 r))
      tensorSize = tsize knownSTK initTarget
      label = "test " ++ prefix
              ++ unwords [ "v" ++ show stkWidth
                         , "m0" ++ " =" ++ show tensorSize ]
      samples = take batchSize xs
  in bench label (whnf testOn samples)

-- | Criterion group for the rank-1 "VTA" MNIST benchmarks.
--
-- Loads the MNIST test set via 'testGlyphsPath' and 'testLabelsPath',
-- shuffles it with @mkStdGen 42@, and keeps the first @chunkLength@
-- samples as the shared benchmark environment.
--
-- For widths 30|10, 300|100 and 500|150 it runs a test benchmark and then
-- a train benchmark. For 1500|500 it runs only the train benchmark. Every
-- benchmark is called with @gamma = 0.02@ and batch size @chunkLength@.
mnistBGroup1VTA :: Int -> Benchmark
mnistBGroup1VTA chunkLength = env loadChunk $ \xs ->
  let groupName = "2-hidden-layer rank 1 VTA MNIST nn with samples: "
                  ++ show chunkLength
      gamma :: Double
      gamma = 0.02
      -- Test benchmark followed by train benchmark for one width pair.
      testAndTrain label w1 w2 =
        [ mnistTestBench1VTA label w1 w2 gamma chunkLength xs
        , mnistTrainBench1VTA label w1 w2 gamma chunkLength xs ]
  in bgroup groupName $ concat
       [ testAndTrain "30|10 " 30 10  -- toy width
       , testAndTrain "300|100 " 300 100  -- ordinary width
       , testAndTrain "500|150 " 500 150  -- another common width
         -- the widest configuration is only benchmarked for training
       , [mnistTrainBench1VTA "1500|500 " 1500 500 gamma chunkLength xs]
       ]
 where
  -- Load, shuffle (fixed seed) and truncate the test data.
  loadChunk :: IO [MnistDataLinearR Double]
  loadChunk = do
    testData0 <- loadMnistData testGlyphsPath testLabelsPath  -- 10k total
    let testData = shuffle (mkStdGen 42) testData0
    return $! map mkMnistDataLinearR $ take chunkLength testData

-- JAX differentiation, Ast term built and differentiated only once
-- and the result interpreted with different inputs in each gradient
-- descent iteration.
mnistTrainBench1VTO
  :: forall r. r ~ Double
  => String
  -> Int -> Int -> Double -> Int -> [MnistDataLinearR r]
  -> Benchmark
mnistTrainBench1VTO :: forall r.
((r :: Type) ~ (Double :: Type)) =>
String
-> Int -> Int -> Double -> Int -> [MnistDataLinearR r] -> Benchmark
mnistTrainBench1VTO String
prefix Int
widthHiddenInt Int
widthHidden2Int
                    Double
gamma Int
batchSize [MnistDataLinearR r]
xs =
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHiddenInt ((forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
 -> Benchmark)
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$ \(SNat n
widthHiddenSNat :: SNat widthHidden) ->
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHidden2Int ((forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
 -> Benchmark)
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$ \(SNat n
widthHidden2SNat :: SNat widthHidden2) ->
  SingletonTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
-> (KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)) =>
    Benchmark)
-> Benchmark
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK
    (SingletonTK (TKS ((':) @Nat 784 ('[] @Nat)) Double)
-> SNat n
-> SingletonTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
forall (t :: TK) (n :: Nat).
SingletonTK t -> SNat n -> SingletonTK (Tups n t)
stkOfListR (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(TKS '[SizeMnistGlyph] r)) (forall (n :: Nat). KnownNat n => SNat n
SNat @widthHidden)) ((KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)) =>
  Benchmark)
 -> Benchmark)
-> (KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)) =>
    Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$
  SingletonTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
-> (KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
    Benchmark)
-> Benchmark
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK
    (SingletonTK (TKS ((':) @Nat n ('[] @Nat)) Float)
-> SNat n
-> SingletonTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
forall (t :: TK) (n :: Nat).
SingletonTK t -> SNat n -> SingletonTK (Tups n t)
stkOfListR (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(TKS '[widthHidden] Float)) (forall (n :: Nat). KnownNat n => SNat n
SNat @widthHidden2)) ((KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
  Benchmark)
 -> Benchmark)
-> (KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
    Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$
  let valsInit :: MnistFcnnRanked1.ADFcnnMnist1Parameters
                    Concrete widthHidden widthHidden2 r
      valsInit :: ADFcnnMnist1Parameters Concrete n n r
valsInit = (ADFcnnMnist1Parameters Concrete n n r, StdGen)
-> ADFcnnMnist1Parameters Concrete n n r
forall a b. (a, b) -> a
fst ((ADFcnnMnist1Parameters Concrete n n r, StdGen)
 -> ADFcnnMnist1Parameters Concrete n n r)
-> (ADFcnnMnist1Parameters Concrete n n r, StdGen)
-> ADFcnnMnist1Parameters Concrete n n r
forall a b. (a -> b) -> a -> b
$ Double
-> StdGen
-> (((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
      Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
     (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
      Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
     (ListR
        SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
      Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
    StdGen)
forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue Double
1 (Int -> StdGen
mkStdGen Int
44)
      targetInit :: Concrete (XParams widthHidden widthHidden2 r)
      targetInit :: Concrete (XParams n n r)
targetInit = forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
  Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
 (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
  Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
 (ListR
    SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
  Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
ADFcnnMnist1Parameters Concrete n n r
valsInit
  in do
{-      -- g is not enough to specialize to Double instead of to r,
        -- despite the declaration of r ~ Double above
        f :: ( MnistFcnnRanked1.ADFcnnMnist1Parameters
                 (AstTensor AstMethodLet FullSpan)
                 widthHidden widthHidden2 r
             , ( AstTensor AstMethodLet FullSpan (TKR 1 r)
               , AstTensor AstMethodLet FullSpan (TKR 1 r) ) )
          -> AstTensor AstMethodLet FullSpan (TKScalar r)
        f = \ (pars, (glyphR, labelR)) ->
          MnistFcnnRanked1.afcnnMnistLoss1
            widthHiddenSNat widthHidden2SNat
            (glyphR, labelR) pars
        g :: ( MnistFcnnRanked1.ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) widthHidden widthHidden2 Double, ( AstTensor AstMethodLet FullSpan (TKR 1 Double), AstTensor AstMethodLet FullSpan (TKR 1 Double) ) ) -> AstTensor AstMethodLet FullSpan (TKScalar Double)
        g = f
-}
    let dataInit :: (Concrete (TKR 1 Double), Concrete (TKR 1 Double))
dataInit = case [MnistDataLinearR r]
xs of
          MnistDataLinearR r
d : [MnistDataLinearR r]
_ -> (Ranked 1 Double -> Concrete (TKR 1 Double)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete (Ranked 1 Double -> Concrete (TKR 1 Double))
-> (Ranked 1 Double -> Concrete (TKR 1 Double))
-> MnistDataLinearR Double
-> (Concrete (TKR 1 Double), Concrete (TKR 1 Double))
forall b c b' c'. (b -> c) -> (b' -> c') -> (b, b') -> (c, c')
forall (a :: Type -> Type -> Type) b c b' c'.
Arrow a =>
a b c -> a b' c' -> a (b, b') (c, c')
*** Ranked 1 Double -> Concrete (TKR 1 Double)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete) MnistDataLinearR r
d
          [] -> String -> (Concrete (TKR 1 Double), Concrete (TKR 1 Double))
forall a. HasCallStack => String -> a
error String
"empty test data"
        f :: ( MnistFcnnRanked1.ADFcnnMnist1Parameters
                 (AstTensor AstMethodLet FullSpan)
                 widthHidden widthHidden2 Double
             , ( AstTensor AstMethodLet FullSpan (TKR 1 Double)
               , AstTensor AstMethodLet FullSpan (TKR 1 Double) ) )
          -> AstTensor AstMethodLet FullSpan (TKScalar Double)
        f :: (ADFcnnMnist1Parameters
   (AstTensor AstMethodLet FullSpan) n n Double,
 (AstTensor AstMethodLet FullSpan (TKR 1 Double),
  AstTensor AstMethodLet FullSpan (TKR 1 Double)))
-> AstTensor AstMethodLet FullSpan (TKScalar Double)
f = \ (ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) n n Double
pars, (AstTensor AstMethodLet FullSpan (TKR 1 Double)
glyphR, AstTensor AstMethodLet FullSpan (TKR 1 Double)
labelR)) ->
          forall (target :: Target) r (widthHidden :: Nat)
       (widthHidden2 :: Nat).
(ADReady target, GoodScalar r, Differentiable r) =>
SNat widthHidden
-> SNat widthHidden2
-> (target (TKR 1 r), target (TKR 1 r))
-> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
-> target (TKScalar r)
MnistFcnnRanked1.afcnnMnistLoss1 @_ @Double
            SNat n
widthHiddenSNat SNat n
widthHidden2SNat
            (AstTensor AstMethodLet FullSpan (TKR 1 Double)
glyphR, AstTensor AstMethodLet FullSpan (TKR 1 Double)
labelR) ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) n n Double
pars
        artRaw :: AstArtifactRev
  (X (((ListR
          n
          (AstTensor
             AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
        AstTensor
          AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
       (ListR
          n
          (AstTensor
             AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
        AstTensor
          AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
       (ListR
          SizeMnistLabel
          (AstTensor
             AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
        AstTensor
          AstMethodLet
          FullSpan
          (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
      (AstTensor AstMethodLet FullSpan (TKR 1 Double),
       AstTensor AstMethodLet FullSpan (TKR 1 Double))))
  (TKScalar Double)
artRaw = ((((ListR
      n
      (AstTensor
         AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
    AstTensor
      AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
   (ListR
      n
      (AstTensor
         AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
    AstTensor
      AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
   (ListR
      SizeMnistLabel
      (AstTensor
         AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
    AstTensor
      AstMethodLet
      FullSpan
      (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
  (AstTensor AstMethodLet FullSpan (TKR 1 Double),
   AstTensor AstMethodLet FullSpan (TKR 1 Double)))
 -> AstTensor AstMethodLet FullSpan (TKScalar Double))
-> Value
     (((ListR
          n
          (AstTensor
             AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
        AstTensor
          AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
       (ListR
          n
          (AstTensor
             AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
        AstTensor
          AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
       (ListR
          SizeMnistLabel
          (AstTensor
             AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
        AstTensor
          AstMethodLet
          FullSpan
          (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
      (AstTensor AstMethodLet FullSpan (TKR 1 Double),
       AstTensor AstMethodLet FullSpan (TKR 1 Double)))
-> AstArtifactRev
     (X (((ListR
             n
             (AstTensor
                AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
           AstTensor
             AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
          (ListR
             n
             (AstTensor
                AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
           AstTensor
             AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
          (ListR
             SizeMnistLabel
             (AstTensor
                AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
           AstTensor
             AstMethodLet
             FullSpan
             (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
         (AstTensor AstMethodLet FullSpan (TKR 1 Double),
          AstTensor AstMethodLet FullSpan (TKR 1 Double))))
     (TKScalar Double)
forall src r tgt.
((X src :: TK) ~ (X (Value src) :: TK), KnownSTK (X src),
 AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
 AdaptableTarget Concrete (Value src),
 (tgt :: Type)
 ~ (AstTensor AstMethodLet FullSpan (TKScalar r) :: Type)) =>
(src -> tgt) -> Value src -> AstArtifactRev (X src) (TKScalar r)
gradArtifact (((ListR
     n
     (AstTensor
        AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
   AstTensor
     AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
  (ListR
     n
     (AstTensor
        AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
   AstTensor
     AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
  (ListR
     SizeMnistLabel
     (AstTensor
        AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
   AstTensor
     AstMethodLet
     FullSpan
     (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
 (AstTensor AstMethodLet FullSpan (TKR 1 Double),
  AstTensor AstMethodLet FullSpan (TKR 1 Double)))
-> AstTensor AstMethodLet FullSpan (TKScalar Double)
(ADFcnnMnist1Parameters
   (AstTensor AstMethodLet FullSpan) n n Double,
 (AstTensor AstMethodLet FullSpan (TKR 1 Double),
  AstTensor AstMethodLet FullSpan (TKR 1 Double)))
-> AstTensor AstMethodLet FullSpan (TKScalar Double)
f (((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
  Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
 (ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
  Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
 (ListR
    SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
  Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
ADFcnnMnist1Parameters Concrete n n r
valsInit, (Concrete (TKR 1 Double), Concrete (TKR 1 Double))
dataInit)
        art :: AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
  (TKScalar Double)
art = AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
  (TKScalar Double)
-> AstArtifactRev
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
                 (TKS ((':) @Nat n ('[] @Nat)) Double))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) Double)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
        (TKProduct (TKR 1 r) (TKR 1 r)))
     (TKScalar Double)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifactGradient AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
  (TKScalar Double)
AstArtifactRev
  (X (((ListR
          n
          (AstTensor
             AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
        AstTensor
          AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
       (ListR
          n
          (AstTensor
             AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
        AstTensor
          AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
       (ListR
          SizeMnistLabel
          (AstTensor
             AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
        AstTensor
          AstMethodLet
          FullSpan
          (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
      (AstTensor AstMethodLet FullSpan (TKR 1 Double),
       AstTensor AstMethodLet FullSpan (TKR 1 Double))))
  (TKScalar Double)
artRaw
        go :: [MnistDataLinearR r]
           -> Concrete (XParams widthHidden widthHidden2 r)
           -> Concrete (XParams widthHidden widthHidden2 r)
        go :: [MnistDataLinearR r]
-> Concrete (XParams n n r) -> Concrete (XParams n n r)
go [] Concrete (XParams n n r)
parameters = Concrete (XParams n n r)
parameters
        go ((Ranked 1 r
glyph, Ranked 1 r
label) : [MnistDataLinearR r]
rest) !Concrete (XParams n n r)
parameters =
          let parametersAndInput :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
parametersAndInput =
                Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
           (TKS ((':) @Nat n ('[] @Nat)) Double))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) Double)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) Double)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> Concrete (TKProduct (TKR 1 r) (TKR 1 r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
                 (TKS ((':) @Nat n ('[] @Nat)) Double))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) Double)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
        (TKProduct (TKR 1 r) (TKR 1 r)))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
           (TKS ((':) @Nat n ('[] @Nat)) Double))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) Double)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) Double)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
Concrete (XParams n n r)
parameters (Concrete (TKR 1 r)
-> Concrete (TKR 1 r) -> Concrete (TKProduct (TKR 1 r) (TKR 1 r))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (Ranked 1 r -> Concrete (TKR 1 r)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
glyph) (Ranked 1 r -> Concrete (TKR 1 r)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
label))
              gradient :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
           (TKS ((':) @Nat n ('[] @Nat)) Double))
        (TKProduct
           (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
           (TKS ((':) @Nat n ('[] @Nat)) Double)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) Double)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
gradient = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
forall (x :: TK) (z :: TK). Concrete (TKProduct x z) -> Concrete x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
               (TKS ((':) @Nat n ('[] @Nat)) Double))
            (TKProduct
               (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
               (TKS ((':) @Nat n ('[] @Nat)) Double)))
         (TKProduct
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) Double)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) Double)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) Double)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) Double)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) Double)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) Double)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
            (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
      (TKProduct (TKR 1 Double) (TKR 1 Double)))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
               (TKS ((':) @Nat n ('[] @Nat)) Double))
            (TKProduct
               (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
               (TKS ((':) @Nat n ('[] @Nat)) Double)))
         (TKProduct
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) Double)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) Double)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) Double)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) Double)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) Double)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) Double)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
            (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
                 (TKS ((':) @Nat n ('[] @Nat)) Double))
              (TKProduct
                 (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
                 (TKS ((':) @Nat n ('[] @Nat)) Double)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
        (TKProduct (TKR 1 Double) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
forall a b. (a -> b) -> a -> b
$ (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct
               (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
               (TKS ((':) @Nat n ('[] @Nat)) Double))
            (TKProduct
               (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
               (TKS ((':) @Nat n ('[] @Nat)) Double)))
         (TKProduct
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) Double)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) Double)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) Double)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) Double)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) Double)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) Double)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
            (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
      (TKProduct (TKR 1 Double) (TKR 1 Double))),
 Concrete (TKScalar Double))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
                 (TKS ((':) @Nat n ('[] @Nat)) Double))
              (TKProduct
                 (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
                 (TKS ((':) @Nat n ('[] @Nat)) Double)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
        (TKProduct (TKR 1 Double) (TKR 1 Double)))
forall a b. (a, b) -> a
fst
                         ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct
                (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
                (TKS ((':) @Nat n ('[] @Nat)) Double))
             (TKProduct
                (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
                (TKS ((':) @Nat n ('[] @Nat)) Double)))
          (TKProduct
             (TKProduct
                (TKS ((':) @Nat n ('[] @Nat)) Double)
                (TKProduct
                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                   (TKProduct
                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                      (TKProduct
                         (TKS ((':) @Nat n ('[] @Nat)) Double)
                         (TKProduct
                            (TKS ((':) @Nat n ('[] @Nat)) Double)
                            (TKProduct
                               (TKS ((':) @Nat n ('[] @Nat)) Double)
                               (TKProduct
                                  (TKS ((':) @Nat n ('[] @Nat)) Double)
                                  (TKProduct
                                     (TKS ((':) @Nat n ('[] @Nat)) Double)
                                     (TKProduct
                                        (TKS ((':) @Nat n ('[] @Nat)) Double)
                                        (TKProduct
                                           (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
             (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
       (TKProduct (TKR 1 Double) (TKR 1 Double))),
  Concrete (TKScalar Double))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
                  (TKS ((':) @Nat n ('[] @Nat)) Double))
               (TKProduct
                  (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
                  (TKS ((':) @Nat n ('[] @Nat)) Double)))
            (TKProduct
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) Double)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) Double)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) Double)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) Double)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) Double)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                                          (TKProduct
                                             (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
               (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
         (TKProduct (TKR 1 Double) (TKR 1 Double))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct
                  (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
                  (TKS ((':) @Nat n ('[] @Nat)) Double))
               (TKProduct
                  (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
                  (TKS ((':) @Nat n ('[] @Nat)) Double)))
            (TKProduct
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) Double)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) Double)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) Double)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) Double)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) Double)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                                          (TKProduct
                                             (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
               (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
         (TKProduct (TKR 1 Double) (TKR 1 Double))),
    Concrete (TKScalar Double))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
                 (TKS ((':) @Nat n ('[] @Nat)) Double))
              (TKProduct
                 (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
                 (TKS ((':) @Nat n ('[] @Nat)) Double)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
        (TKProduct (TKR 1 Double) (TKR 1 Double)))
forall a b. (a -> b) -> a -> b
$ AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
  (TKScalar Double)
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
                 (TKS ((':) @Nat n ('[] @Nat)) Double))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) Double)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
        (TKProduct (TKR 1 r) (TKR 1 r)))
-> Maybe (Concrete (ADTensorKind (TKScalar Double)))
-> (Concrete
      (ADTensorKind
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct
                     (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
                     (TKS ((':) @Nat n ('[] @Nat)) Double))
                  (TKProduct
                     (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                     (TKS ((':) @Nat n ('[] @Nat)) Double)))
               (TKProduct
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) Double)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) Double)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) Double)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) Double)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                                          (TKProduct
                                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                                             (TKProduct
                                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                                TKUnit))))))))))
                  (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
            (TKProduct (TKR 1 r) (TKR 1 r)))),
    Concrete (TKScalar Double))
forall (x :: TK) (z :: TK).
AstArtifactRev x z
-> Concrete x
-> Maybe (Concrete (ADTensorKind z))
-> (Concrete (ADTensorKind x), Concrete z)
revInterpretArtifact AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
  (TKScalar Double)
art Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
parametersAndInput Maybe (Concrete (ADTensorKind (TKScalar Double)))
Maybe (Concrete (TKScalar Double))
forall a. Maybe a
Nothing
          in [MnistDataLinearR r]
-> Concrete (XParams n n r) -> Concrete (XParams n n r)
go [MnistDataLinearR r]
rest (Double
-> SingletonTK
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> Concrete
     (ADTensorKind
        (TKProduct
           (TKProduct
              (TKProduct
                 (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
                 (TKS ((':) @Nat n ('[] @Nat)) Double))
              (TKProduct
                 (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
                 (TKS ((':) @Nat n ('[] @Nat)) Double)))
           (TKProduct
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double)
                                         (TKProduct
                                            (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
forall (y :: TK).
Double
-> SingletonTK y
-> Concrete y
-> Concrete (ADTensorKind y)
-> Concrete y
updateWithGradient Double
gamma SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
           (TKS ((':) @Nat n ('[] @Nat)) Double))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) Double)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) Double)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
           (TKS ((':) @Nat n ('[] @Nat)) Double))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) Double)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) Double)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
Concrete (XParams n n r)
parameters Concrete
  (ADTensorKind
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
           (TKS ((':) @Nat n ('[] @Nat)) Double))
        (TKProduct
           (ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
           (TKS ((':) @Nat n ('[] @Nat)) Double)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) Double)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
gradient)
        chunk :: [MnistDataLinearR r]
chunk = Int -> [MnistDataLinearR r] -> [MnistDataLinearR r]
forall a. Int -> [a] -> [a]
take Int
batchSize [MnistDataLinearR r]
xs
        gradf :: [MnistDataLinearR r] -> Concrete (XParams n n r)
gradf [MnistDataLinearR r]
c = [MnistDataLinearR r]
-> Concrete (XParams n n r) -> Concrete (XParams n n r)
go [MnistDataLinearR r]
c Concrete (XParams n n r)
targetInit
        name :: String
name =
          String
prefix
          String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords
               [ String
"v" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show (SingletonTK (XParams n n r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
                              (SingletonTK (XParams n n r) -> Int)
-> SingletonTK (XParams n n r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams widthHidden widthHidden2 r))
               , String
"m0" String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
" =" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
           (TKS ((':) @Nat n ('[] @Nat)) Double))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) Double)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) Double)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
           (TKS ((':) @Nat n ('[] @Nat)) Double))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) Double)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) Double)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
           (TKS ((':) @Nat n ('[] @Nat)) Double))
        (TKProduct
           (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
           (TKS ((':) @Nat n ('[] @Nat)) Double)))
     (TKProduct
        (TKProduct
           (TKS ((':) @Nat n ('[] @Nat)) Double)
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
Concrete (XParams n n r)
targetInit) ]
    String -> Benchmarkable -> Benchmark
bench String
name (Benchmarkable -> Benchmark) -> Benchmarkable -> Benchmark
forall a b. (a -> b) -> a -> b
$ ([MnistDataLinearR r]
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
               (TKS ((':) @Nat n ('[] @Nat)) Double))
            (TKProduct
               (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
               (TKS ((':) @Nat n ('[] @Nat)) Double)))
         (TKProduct
            (TKProduct
               (TKS ((':) @Nat n ('[] @Nat)) Double)
               (TKProduct
                  (TKS ((':) @Nat n ('[] @Nat)) Double)
                  (TKProduct
                     (TKS ((':) @Nat n ('[] @Nat)) Double)
                     (TKProduct
                        (TKS ((':) @Nat n ('[] @Nat)) Double)
                        (TKProduct
                           (TKS ((':) @Nat n ('[] @Nat)) Double)
                           (TKProduct
                              (TKS ((':) @Nat n ('[] @Nat)) Double)
                              (TKProduct
                                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                                 (TKProduct
                                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                                    (TKProduct
                                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                                       (TKProduct
                                          (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
            (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
-> [MnistDataLinearR r] -> Benchmarkable
forall b a. NFData b => (a -> b) -> a -> Benchmarkable
nf [MnistDataLinearR r]
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
              (TKS ((':) @Nat n ('[] @Nat)) Double))
           (TKProduct
              (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
              (TKS ((':) @Nat n ('[] @Nat)) Double)))
        (TKProduct
           (TKProduct
              (TKS ((':) @Nat n ('[] @Nat)) Double)
              (TKProduct
                 (TKS ((':) @Nat n ('[] @Nat)) Double)
                 (TKProduct
                    (TKS ((':) @Nat n ('[] @Nat)) Double)
                    (TKProduct
                       (TKS ((':) @Nat n ('[] @Nat)) Double)
                       (TKProduct
                          (TKS ((':) @Nat n ('[] @Nat)) Double)
                          (TKProduct
                             (TKS ((':) @Nat n ('[] @Nat)) Double)
                             (TKProduct
                                (TKS ((':) @Nat n ('[] @Nat)) Double)
                                (TKProduct
                                   (TKS ((':) @Nat n ('[] @Nat)) Double)
                                   (TKProduct
                                      (TKS ((':) @Nat n ('[] @Nat)) Double)
                                      (TKProduct
                                         (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
[MnistDataLinearR r] -> Concrete (XParams n n r)
gradf [MnistDataLinearR r]
chunk

-- | Test-set benchmark for the rank-1 (lists of vectors) network, listed
-- under the VTO label.
--
-- Currently this forwards every argument unchanged to 'mnistTestBench1VTA',
-- so both labels benchmark the same code.
mnistTestBench1VTO
  :: forall r. r ~ Double
  => String
  -> Int -> Int -> Double -> Int -> [MnistDataLinearR r]
  -> Benchmark
mnistTestBench1VTO prefix widthHidden widthHidden2 gamma batchSize xs =
  mnistTestBench1VTA prefix widthHidden widthHidden2 gamma batchSize xs

-- | Benchmark group for the rank-1 VTO MNIST networks.
--
-- The shared environment is built as follows:
--
-- * load the MNIST test set;
-- * shuffle it with the fixed generator @mkStdGen 42@;
-- * keep the first @chunkLength@ samples, converted with
--   'mkMnistDataLinearR'.
--
-- On that data the group runs a test and a train benchmark for each of the
-- widths 30|10, 300|100 and 500|150. The widest configuration, 1500|500, is
-- benchmarked for training only. Every benchmark uses learning rate 0.02 and
-- batch size @chunkLength@.
mnistBGroup1VTO :: Int -> Benchmark
mnistBGroup1VTO chunkLength =
  env loadChunk $ \xs ->
    bgroup ("2-hidden-layer rank 1 VTO MNIST nn with samples: "
            ++ show chunkLength) $
      concatMap (testAndTrain xs) widths
      ++ [ mnistTrainBench1VTO "1500|500 " 1500 500 gamma chunkLength xs ]
 where
  -- Learning rate shared by every benchmark in the group.
  gamma :: Double
  gamma = 0.02

  -- (name prefix, first hidden width, second hidden width)
  widths :: [(String, Int, Int)]
  widths =
    [ ("30|10 ", 30, 10)
    , ("300|100 ", 300, 100)
    , ("500|150 ", 500, 150)
    ]

  -- The test benchmark followed by the train benchmark for one width pair.
  testAndTrain :: [MnistDataLinearR Double] -> (String, Int, Int)
               -> [Benchmark]
  testAndTrain xs (label, w1, w2) =
    [ mnistTestBench1VTO label w1 w2 gamma chunkLength xs
    , mnistTrainBench1VTO label w1 w2 gamma chunkLength xs
    ]

  -- Deterministically shuffled prefix of the test set. Criterion's 'env'
  -- fully evaluates it before any benchmark in the group runs.
  loadChunk :: IO [MnistDataLinearR Double]
  loadChunk = do
    testData0 <- loadMnistData testGlyphsPath testLabelsPath  -- 10k total
    let testData = shuffle (mkStdGen 42) testData0
    return $! map mkMnistDataLinearR $ take chunkLength testData


-- * Using matrices, which is rank 2

-- | Training benchmark for the rank-2 (matrix) MNIST network.
--
-- This is POPL differentiation, straight via the ADVal instance of
-- RankedTensor, which side-steps vectorization.
--
-- Initial parameters come from 'randomValue' (first argument @1@, generator
-- @mkStdGen 44@) for the shaped parameter type with hidden widths
-- @widthHidden@ and @widthHidden2@. They are then passed through
-- 'forgetShape' and used as the starting point of 'sgd'.
--
-- The benchmarked function runs 'sgd' with learning rate @gamma@ over the
-- first @batchSize@ elements of @xs@ and keeps only the resulting
-- parameters.
mnistTrainBench2VTA
  :: forall r. r ~ Double
  => String
  -> Int -> Int -> Double -> Int -> [MnistDataLinearR r]
  -> Benchmark
mnistTrainBench2VTA prefix widthHidden widthHidden2
                    gamma batchSize xs =
  withSNat widthHidden $ \(SNat @hidden1) ->
  withSNat widthHidden2 $ \(SNat @hidden2) ->
    let params0 =
          forgetShape $ fst $
            randomValue
              @(Concrete (X (MnistFcnnRanked2.ADFcnnMnist2ParametersShaped
                               Concrete hidden1 hidden2 r Float)))
              1 (mkStdGen 44)
{-      lossOn :: MnistDataLinearR r -> ADVal Concrete (XParams2 r Float)
               -> ADVal Concrete (TKScalar r)
        lossOn (glyph, label) adinputs =
          MnistFcnnRanked2.afcnnMnistLoss2
            (rconcrete glyph, rconcrete label) (fromTarget adinputs) -}
        -- Per-sample loss, specialised to Double (r ~ Double anyway).
        lossOn :: MnistDataLinearR Double
               -> ADVal Concrete (XParams2 Double Float)
               -> ADVal Concrete (TKScalar Double)
        lossOn (glyph, label) adinputs =
          MnistFcnnRanked2.afcnnMnistLoss2
            (rconcrete glyph, rconcrete label) (fromTarget adinputs)
        batch = take batchSize xs
        -- Trained parameters only; the loss half of sgd's result is dropped.
        train c = fst $ sgd gamma lossOn c params0
        benchName =
          prefix
          ++ unwords
               [ "v0 m" ++ show (widthSTK $ knownSTK @(XParams2 r Float))
               , "=" ++ show (tsize knownSTK params0) ]
    in bench benchName $ nf train batch

mnistTestBench2VTA
  :: forall r. r ~ Double
  => String
  -> Int -> Int -> Double -> Int -> [MnistDataLinearR r]
  -> Benchmark
mnistTestBench2VTA :: forall r.
((r :: Type) ~ (Double :: Type)) =>
String
-> Int -> Int -> Double -> Int -> [MnistDataLinearR r] -> Benchmark
mnistTestBench2VTA String
prefix Int
widthHidden Int
widthHidden2
                   Double
_gamma Int
batchSize [MnistDataLinearR r]
xs =
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHidden ((forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
 -> Benchmark)
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$ \(SNat @widthHidden) ->
  Int
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHidden2 ((forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
 -> Benchmark)
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$ \(SNat @widthHidden2) ->
  let targetInit :: NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar Double))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double)))
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat)))
              (TKScalar Double))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
targetInit =
        Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar Double))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double)))
        (TKProduct
           (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat)))
           (TKScalar Double))
        (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar Double))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double)))
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat)))
                 (TKScalar Double))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar Double))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double)))
         (TKProduct
            (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat)))
            (TKScalar Double))
         (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
 -> NoShape
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar Double))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double)))
               (TKProduct
                  (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double))))
            (TKProduct
               (TKS2
                  ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat)))
                  (TKScalar Double))
               (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar Double))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double)))
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat)))
              (TKScalar Double))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar Double))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double)))
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat)))
                 (TKScalar Double))
              (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
forall a b. (a -> b) -> a -> b
$ (Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r Float)),
 StdGen)
-> Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r Float))
forall a b. (a, b) -> a
fst
        ((Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r Float)),
  StdGen)
 -> Concrete
      (X (ADFcnnMnist2ParametersShaped Concrete n n r Float)))
-> (Concrete
      (X (ADFcnnMnist2ParametersShaped Concrete n n r Float)),
    StdGen)
-> Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r Float))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue
            @(Concrete (X (MnistFcnnRanked2.ADFcnnMnist2ParametersShaped
                             Concrete widthHidden widthHidden2 r Float)))
            Double
1 (Int -> StdGen
mkStdGen Int
44)
  in do
    let chunk :: [MnistDataLinearR r]
chunk = Int -> [MnistDataLinearR r] -> [MnistDataLinearR r]
forall a. Int -> [a] -> [a]
take Int
batchSize [MnistDataLinearR r]
xs
        score :: [MnistDataLinearR r] -> r
score [MnistDataLinearR r]
c = [MnistDataLinearR r]
-> ADFcnnMnist2Parameters Concrete r Float -> r
forall (target :: Target) r q.
((target :: Target) ~ (Concrete :: Target), GoodScalar r,
 Differentiable r, GoodScalar q, Differentiable q) =>
[MnistDataLinearR r] -> ADFcnnMnist2Parameters target r q -> r
MnistFcnnRanked2.afcnnMnistTest2 [MnistDataLinearR r]
c (Concrete (XParams2 r Float)
-> ADFcnnMnist2Parameters Concrete r Float
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete (XParams2 r Float)
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar Double))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double)))
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat)))
              (TKScalar Double))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
targetInit)
        name :: String
name =
          String
"test " String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
prefix
          String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords
               [ String
"v0 m" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show (SingletonTK (XParams2 r Float) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK (SingletonTK (XParams2 r Float) -> Int)
-> SingletonTK (XParams2 r Float) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams2 r Float))
               , String
"=" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar Double))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double)))
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar Float))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar Double))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat)))
              (TKScalar Double))
           (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
targetInit) ]
    String -> Benchmarkable -> Benchmark
bench String
name (Benchmarkable -> Benchmark) -> Benchmarkable -> Benchmark
forall a b. (a -> b) -> a -> b
$ ([MnistDataLinearR r] -> r)
-> [MnistDataLinearR r] -> Benchmarkable
forall a b. (a -> b) -> a -> Benchmarkable
whnf [MnistDataLinearR r] -> r
score [MnistDataLinearR r]
chunk

-- Benchmark group for the rank-2 VTA pipeline. The MNIST test set is loaded
-- once through criterion's 'env', and the same 'chunkLength' samples are then
-- shared by test and train benchmarks for several hidden-layer widths.
mnistBGroup2VTA :: Int -> Benchmark
mnistBGroup2VTA chunkLength = env loadTestChunk benchmarksFor
  where
    -- Test data with a fixed shuffle seed, truncated to 'chunkLength' samples
    -- and converted to the linear rank-1 representation.
    loadTestChunk :: IO [MnistDataLinearR Double]
    loadTestChunk = do
      testData0 <- loadMnistData testGlyphsPath testLabelsPath  -- 10k total
      let testData = shuffle (mkStdGen 42) testData0
      return $! map mkMnistDataLinearR $ take chunkLength testData

    benchmarksFor :: [MnistDataLinearR Double] -> Benchmark
    benchmarksFor xs =
      let gamma = 0.02 :: Double
          testB label w1 w2 =
            mnistTestBench2VTA label w1 w2 gamma chunkLength xs
          trainB label w1 w2 =
            mnistTrainBench2VTA label w1 w2 gamma chunkLength xs
      in bgroup ("2-hidden-layer rank 2 VTA MNIST nn with samples: "
                 ++ show chunkLength)
                [ testB "30|10 " 30 10
                , trainB "30|10 " 30 10
                , testB "300|100 " 300 100
                , trainB "300|100 " 300 100
                , testB "500|150 " 500 150
                , trainB "500|150 " 500 150
                -- the largest size only gets a training benchmark
                , trainB "1500|500 " 1500 500
                ]

-- JAX differentiation: the Ast term is built and differentiated only once,
-- and the result is interpreted with different inputs in each gradient
-- descent iteration.

-- Measures only compilation time (building and simplifying the gradient
-- artifact).
-- The timed function builds the reverse-mode gradient artifact for the given
-- hidden widths and simplifies it. The initial parameters returned alongside
-- the artifact are discarded. 'whnf' forces the simplified artifact only to
-- weak head normal form.
mnistTrainBench2VTC
  :: String
  -> Int -> Int
  -> Benchmark
mnistTrainBench2VTC prefix widthHidden widthHidden2 =
  bench prefix (whnf compileFor widthHidden2)
  where
    compileFor w2 =
      let (_, artRaw) =
            MnistFcnnRanked2.mnistTrainBench2VTOGradient
              @Double (Proxy @Float) IgnoreIncomingCotangent
              1 (mkStdGen 44) widthHidden w2
      in simplifyArtifactGradient artRaw

-- Benchmark group for compiling the rank-2 gradient artifact at several
-- hidden-layer widths. 'chunkLength' only appears in the group name; the
-- compilation benchmarks themselves use no samples.
mnistBGroup2VTC :: Int -> Benchmark
mnistBGroup2VTC chunkLength = bgroup groupName (map benchFor sizes)
  where
    groupName =
      "2-hidden-layer rank 2 VTC compilation MNIST nn with samples: "
      ++ show chunkLength

    -- (label, first hidden width, second hidden width)
    sizes :: [(String, Int, Int)]
    sizes =
      [ ("30|10 ", 30, 10)
      , ("300|100 ", 300, 100)
      , ("500|150 ", 500, 150)
      , ("1500|500 ", 1500, 500)
      ]

    benchFor (label, w1, w2) = mnistTrainBench2VTC label w1 w2

-- The same as above, but measures only runtime: the initial parameters and
-- the gradient artifact are passed in already built.
-- Runs one pass of gradient descent over the first 'batchSize' samples of
-- 'xs', starting from 'targetInit'. Each step interprets the precompiled
-- artifact 'art' on one sample. 'nf' forces the final parameters to normal
-- form, so the whole pass is inside the timed region.
mnistTrainBench2VTOO
  :: forall r. r ~ Double
  => String
  -> Double -> Int -> [MnistDataLinearR r]
  -> ( Concrete (XParams2 r Float)
     , AstArtifactRev
         (TKProduct
            (XParams2 r Float)
            (TKProduct (TKR2 1 (TKScalar Double))
                       (TKR2 1 (TKScalar Double))))
         (TKScalar r) )
  -> Benchmark
mnistTrainBench2VTOO prefix gamma batchSize xs (targetInit, art) =
  -- The result is a pure 'Benchmark', not a monadic action, so a plain
  -- let-in is used instead of the redundant 'do'.
  let go :: [MnistDataLinearR r] -> Concrete (XParams2 r Float)
         -> Concrete (XParams2 r Float)
      go [] parameters = parameters
      -- The bang keeps the accumulated parameters evaluated at each step.
      go ((glyph, label) : rest) !parameters =
        let parametersAndInput =
              tpair parameters (tpair (rconcrete glyph) (rconcrete label))
            -- The artifact differentiates w.r.t. parameters and input
            -- together; only the parameter component of the gradient
            -- is kept.
            gradient = tproject1 $ fst
                       $ revInterpretArtifact art parametersAndInput Nothing
        in go rest (updateWithGradient gamma knownSTK parameters gradient)
      chunk = take batchSize xs
      gradf c = go c targetInit
      name =
        prefix
        ++ unwords
             [ "v0 m" ++ show (widthSTK $ knownSTK @(XParams2 r Float))
             , "=" ++ show (tsize knownSTK targetInit) ]
  in bench name $ nf gradf chunk

-- The same as above, but the gradient artifact is also built and simplified
-- here (outside `bench`) before being passed on; the timed part is runtime.
-- Builds the initial parameters and the gradient artifact for the given
-- hidden widths, simplifies the artifact, and delegates the timed training
-- pass to 'mnistTrainBench2VTOO'. The bang bindings make the simplified
-- artifact evaluated (to WHNF) before the 'Benchmark' is produced.
mnistTrainBench2VTO
  :: forall r. r ~ Double
  => String
  -> Int -> Int -> Double -> Int -> [MnistDataLinearR r]
  -> Benchmark
mnistTrainBench2VTO prefix widthHidden widthHidden2 gamma batchSize xs =
  mnistTrainBench2VTOO prefix gamma batchSize xs (paramsInit, simplified)
  where
    (!paramsInit, !unsimplified) =
      MnistFcnnRanked2.mnistTrainBench2VTOGradient
        @Double (Proxy @Float) IgnoreIncomingCotangent
        1 (mkStdGen 44) widthHidden widthHidden2
    !simplified = simplifyArtifactGradient unsimplified

mnistBGroup2VTO :: Int -> Benchmark
mnistBGroup2VTO :: Int -> Benchmark
mnistBGroup2VTO Int
chunkLength =
  let (!Concrete (XParams2 Double Float)
targetInit, !AstArtifactRev
  (TKProduct
     (XParams2 Double Float) (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
artRaw) =
        forall r q.
(GoodScalar r, Differentiable r, GoodScalar q, Differentiable q) =>
Proxy @Type q
-> IncomingCotangentHandling
-> Double
-> StdGen
-> Int
-> Int
-> (Concrete (XParams2 r q),
    AstArtifactRev
      (TKProduct
         (XParams2 r q)
         (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKScalar r))
MnistFcnnRanked2.mnistTrainBench2VTOGradient
          @Double (forall t. Proxy @Type t
forall {k} (t :: k). Proxy @k t
Proxy @Float) IncomingCotangentHandling
IgnoreIncomingCotangent
          Double
1 (Int -> StdGen
mkStdGen Int
44) Int
1500 Int
500
      !art :: AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
art = AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
-> AstArtifactRev
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR 1 Double) (TKR 1 Double)))
     (TKScalar Double)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifactGradient AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
artRaw  -- no NFData for AST
  in IO [MnistDataLinearR Double]
-> ([MnistDataLinearR Double] -> Benchmark) -> Benchmark
forall env. NFData env => IO env -> (env -> Benchmark) -> Benchmark
env (do
    testData0 <- String -> String -> IO [MnistData Double]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
testGlyphsPath String
testLabelsPath  -- 10k total
    let testData = StdGen -> [MnistData Double] -> [MnistData Double]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen Int
42) [MnistData Double]
testData0
    return $! map mkMnistDataLinearR $ take chunkLength testData) (([MnistDataLinearR Double] -> Benchmark) -> Benchmark)
-> ([MnistDataLinearR Double] -> Benchmark) -> Benchmark
forall a b. (a -> b) -> a -> b
$
  \ [MnistDataLinearR Double]
xs ->
   String -> [Benchmark] -> Benchmark
bgroup (String
"2-hidden-layer rank 2 VTO runtime MNIST nn with samples: "
           String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show Int
chunkLength)
     [ String
-> Double
-> Int
-> [MnistDataLinearR Double]
-> (Concrete (XParams2 Double Float),
    AstArtifactRev
      (TKProduct
         (XParams2 Double Float) (TKProduct (TKR 1 Double) (TKR 1 Double)))
      (TKScalar Double))
-> Benchmark
forall r.
((r :: Type) ~ (Double :: Type)) =>
String
-> Double
-> Int
-> [MnistDataLinearR r]
-> (Concrete (XParams2 r Float),
    AstArtifactRev
      (TKProduct
         (XParams2 r Float) (TKProduct (TKR 1 Double) (TKR 1 Double)))
      (TKScalar r))
-> Benchmark
mnistTrainBench2VTOO String
"1500|500 " Double
0.02 Int
chunkLength [MnistDataLinearR Double]
xs (Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
Concrete (XParams2 Double Float)
targetInit, AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
AstArtifactRev
  (TKProduct
     (XParams2 Double Float) (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
art)
     ]

-- The same as above, but without simplifying the gradient.
mnistTrainBench2VTOZ
  :: forall r. r ~ Double
  => String
  -> Double -> Int -> [MnistDataLinearR r]
  -> ( Concrete (XParams2 r Float)
     , AstArtifactRev
         (TKProduct
            (XParams2 r Float)
            (TKProduct (TKR2 1 (TKScalar Double))
                       (TKR2 1 (TKScalar Double))))
         (TKScalar r) )
  -> Benchmark
mnistTrainBench2VTOZ :: forall r.
((r :: Type) ~ (Double :: Type)) =>
String
-> Double
-> Int
-> [MnistDataLinearR r]
-> (Concrete (XParams2 r Float),
    AstArtifactRev
      (TKProduct
         (XParams2 r Float) (TKProduct (TKR 1 Double) (TKR 1 Double)))
      (TKScalar r))
-> Benchmark
mnistTrainBench2VTOZ String
prefix Double
gamma Int
batchSize [MnistDataLinearR r]
xs (Concrete (XParams2 r Float)
targetInit, AstArtifactRev
  (TKProduct
     (XParams2 r Float) (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar r)
art) = do
    let go :: [MnistDataLinearR r] -> Concrete (XParams2 r Float)
           -> Concrete (XParams2 r Float)
        go :: [MnistDataLinearR r]
-> Concrete (XParams2 r Float) -> Concrete (XParams2 r Float)
go [] Concrete (XParams2 r Float)
parameters = Concrete (XParams2 r Float)
parameters
        go ((Ranked 1 r
glyph, Ranked 1 r
label) : [MnistDataLinearR r]
rest) !Concrete (XParams2 r Float)
parameters =
          let parametersAndInput :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
parametersAndInput =
                Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete (TKProduct (TKR 1 r) (TKR 1 r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR 1 r) (TKR 1 r)))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
Concrete (XParams2 r Float)
parameters (Concrete (TKR 1 r)
-> Concrete (TKR 1 r) -> Concrete (TKProduct (TKR 1 r) (TKR 1 r))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (Ranked 1 r -> Concrete (TKR 1 r)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
glyph) (Ranked 1 r -> Concrete (TKR 1 r)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
label))
              gradient :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
gradient = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
forall (x :: TK) (z :: TK). Concrete (TKProduct x z) -> Concrete x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
      (TKProduct (TKR 1 Double) (TKR 1 Double)))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR 1 Double) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
forall a b. (a -> b) -> a -> b
$ (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
      (TKProduct (TKR 1 Double) (TKR 1 Double))),
 Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR 1 Double) (TKR 1 Double)))
forall a b. (a, b) -> a
fst
                         ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
             (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
          (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
       (TKProduct (TKR 1 Double) (TKR 1 Double))),
  Concrete (TKScalar r))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
               (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR 1 Double) (TKR 1 Double))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
               (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR 1 Double) (TKR 1 Double))),
    Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR 1 Double) (TKR 1 Double)))
forall a b. (a -> b) -> a -> b
$ AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
  (TKScalar r)
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR 1 r) (TKR 1 r)))
-> Maybe (Concrete (ADTensorKind (TKScalar r)))
-> (Concrete
      (ADTensorKind
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
                  (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
            (TKProduct (TKR 1 r) (TKR 1 r)))),
    Concrete (TKScalar r))
forall (x :: TK) (z :: TK).
AstArtifactRev x z
-> Concrete x
-> Maybe (Concrete (ADTensorKind z))
-> (Concrete (ADTensorKind x), Concrete z)
revInterpretArtifact AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
  (TKScalar r)
AstArtifactRev
  (TKProduct
     (XParams2 r Float) (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar r)
art Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
parametersAndInput Maybe (Concrete (ADTensorKind (TKScalar r)))
Maybe (Concrete (TKScalar Double))
forall a. Maybe a
Nothing
          in [MnistDataLinearR r]
-> Concrete (XParams2 r Float) -> Concrete (XParams2 r Float)
go [MnistDataLinearR r]
rest (Double
-> SingletonTK
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete
     (ADTensorKind
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
forall (y :: TK).
Double
-> SingletonTK y
-> Concrete y
-> Concrete (ADTensorKind y)
-> Concrete y
updateWithGradient Double
gamma SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
Concrete (XParams2 r Float)
parameters Concrete
  (ADTensorKind
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
gradient)
        chunk :: [MnistDataLinearR r]
chunk = Int -> [MnistDataLinearR r] -> [MnistDataLinearR r]
forall a. Int -> [a] -> [a]
take Int
batchSize [MnistDataLinearR r]
xs
        gradf :: [MnistDataLinearR r] -> Concrete (XParams2 r Float)
gradf [MnistDataLinearR r]
c = [MnistDataLinearR r]
-> Concrete (XParams2 r Float) -> Concrete (XParams2 r Float)
go [MnistDataLinearR r]
c Concrete (XParams2 r Float)
targetInit
        name :: String
name =
          String
prefix
          String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords
               [ String
"v0 m" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show (SingletonTK (XParams2 r Float) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK (SingletonTK (XParams2 r Float) -> Int)
-> SingletonTK (XParams2 r Float) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams2 r Float))
               , String
"=" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
Concrete (XParams2 r Float)
targetInit) ]
    String -> Benchmarkable -> Benchmark
bench String
name (Benchmarkable -> Benchmark) -> Benchmarkable -> Benchmark
forall a b. (a -> b) -> a -> b
$ ([MnistDataLinearR r]
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
-> [MnistDataLinearR r] -> Benchmarkable
forall b a. NFData b => (a -> b) -> a -> Benchmarkable
nf [MnistDataLinearR r]
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
[MnistDataLinearR r] -> Concrete (XParams2 r Float)
gradf [MnistDataLinearR r]
chunk

mnistBGroup2VTOZ :: Int -> Benchmark
mnistBGroup2VTOZ :: Int -> Benchmark
mnistBGroup2VTOZ Int
chunkLength =
  let (!Concrete (XParams2 Double Float)
targetInit, !AstArtifactRev
  (TKProduct
     (XParams2 Double Float) (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
art) =
        forall r q.
(GoodScalar r, Differentiable r, GoodScalar q, Differentiable q) =>
Proxy @Type q
-> IncomingCotangentHandling
-> Double
-> StdGen
-> Int
-> Int
-> (Concrete (XParams2 r q),
    AstArtifactRev
      (TKProduct
         (XParams2 r q)
         (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKScalar r))
MnistFcnnRanked2.mnistTrainBench2VTOGradient
          @Double (forall t. Proxy @Type t
forall {k} (t :: k). Proxy @k t
Proxy @Float) IncomingCotangentHandling
IgnoreIncomingCotangent
          Double
1 (Int -> StdGen
mkStdGen Int
44) Int
1500 Int
500
  in IO [MnistDataLinearR Double]
-> ([MnistDataLinearR Double] -> Benchmark) -> Benchmark
forall env. NFData env => IO env -> (env -> Benchmark) -> Benchmark
env (do
    testData0 <- String -> String -> IO [MnistData Double]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
testGlyphsPath String
testLabelsPath  -- 10k total
    let testData = StdGen -> [MnistData Double] -> [MnistData Double]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen Int
42) [MnistData Double]
testData0
    return $! map mkMnistDataLinearR $ take chunkLength testData) (([MnistDataLinearR Double] -> Benchmark) -> Benchmark)
-> ([MnistDataLinearR Double] -> Benchmark) -> Benchmark
forall a b. (a -> b) -> a -> b
$
  \ [MnistDataLinearR Double]
xs ->
   String -> [Benchmark] -> Benchmark
bgroup (String
"2-hidden-layer rank 2 VTOZ runtime MNIST nn with samples: "
           String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show Int
chunkLength)
     [ String
-> Double
-> Int
-> [MnistDataLinearR Double]
-> (Concrete (XParams2 Double Float),
    AstArtifactRev
      (TKProduct
         (XParams2 Double Float) (TKProduct (TKR 1 Double) (TKR 1 Double)))
      (TKScalar Double))
-> Benchmark
forall r.
((r :: Type) ~ (Double :: Type)) =>
String
-> Double
-> Int
-> [MnistDataLinearR r]
-> (Concrete (XParams2 r Float),
    AstArtifactRev
      (TKProduct
         (XParams2 r Float) (TKProduct (TKR 1 Double) (TKR 1 Double)))
      (TKScalar r))
-> Benchmark
mnistTrainBench2VTOO String
"1500|500 " Double
0.02 Int
chunkLength [MnistDataLinearR Double]
xs (Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
Concrete (XParams2 Double Float)
targetInit, AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
AstArtifactRev
  (TKProduct
     (XParams2 Double Float) (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar Double)
art)
     ]

-- The same as above, but without any simplification, even the smart
-- constructors.
mnistTrainBench2VTOX
  :: forall r. r ~ Double
  => String
  -> Double -> Int -> [MnistDataLinearR r]
  -> ( Concrete (XParams2 r Float)
     , AstArtifactRev
         (TKProduct
            (XParams2 r Float)
            (TKProduct (TKR2 1 (TKScalar Double))
                       (TKR2 1 (TKScalar Double))))
         (TKScalar r) )
  -> Benchmark
mnistTrainBench2VTOX :: forall r.
((r :: Type) ~ (Double :: Type)) =>
String
-> Double
-> Int
-> [MnistDataLinearR r]
-> (Concrete (XParams2 r Float),
    AstArtifactRev
      (TKProduct
         (XParams2 r Float) (TKProduct (TKR 1 Double) (TKR 1 Double)))
      (TKScalar r))
-> Benchmark
mnistTrainBench2VTOX String
prefix Double
gamma Int
batchSize [MnistDataLinearR r]
xs (Concrete (XParams2 r Float)
targetInit, AstArtifactRev
  (TKProduct
     (XParams2 r Float) (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar r)
art) = do
    let go :: [MnistDataLinearR r] -> Concrete (XParams2 r Float)
           -> Concrete (XParams2 r Float)
        go :: [MnistDataLinearR r]
-> Concrete (XParams2 r Float) -> Concrete (XParams2 r Float)
go [] Concrete (XParams2 r Float)
parameters = Concrete (XParams2 r Float)
parameters
        go ((Ranked 1 r
glyph, Ranked 1 r
label) : [MnistDataLinearR r]
rest) !Concrete (XParams2 r Float)
parameters =
          let parametersAndInput :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
parametersAndInput =
                Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete (TKProduct (TKR 1 r) (TKR 1 r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR 1 r) (TKR 1 r)))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
Concrete (XParams2 r Float)
parameters (Concrete (TKR 1 r)
-> Concrete (TKR 1 r) -> Concrete (TKProduct (TKR 1 r) (TKR 1 r))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (Ranked 1 r -> Concrete (TKR 1 r)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
glyph) (Ranked 1 r -> Concrete (TKR 1 r)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
label))
              gradient :: Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
gradient = Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 Double) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
forall (x :: TK) (z :: TK). Concrete (TKProduct x z) -> Concrete x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
      (TKProduct (TKR 1 Double) (TKR 1 Double)))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR 1 Double) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
forall a b. (a -> b) -> a -> b
$ (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
      (TKProduct (TKR 1 Double) (TKR 1 Double))),
 Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR 1 Double) (TKR 1 Double)))
forall a b. (a, b) -> a
fst
                         ((Concrete
    (TKProduct
       (TKProduct
          (TKProduct
             (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
             (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
          (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
       (TKProduct (TKR 1 Double) (TKR 1 Double))),
  Concrete (TKScalar r))
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
               (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR 1 Double) (TKR 1 Double))))
-> (Concrete
      (TKProduct
         (TKProduct
            (TKProduct
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
               (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
         (TKProduct (TKR 1 Double) (TKR 1 Double))),
    Concrete (TKScalar r))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR 1 Double) (TKR 1 Double)))
forall a b. (a -> b) -> a -> b
$ AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
  (TKScalar r)
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
        (TKProduct (TKR 1 r) (TKR 1 r)))
-> Maybe (Concrete (ADTensorKind (TKScalar r)))
-> (Concrete
      (ADTensorKind
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
                  (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
               (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
            (TKProduct (TKR 1 r) (TKR 1 r)))),
    Concrete (TKScalar r))
forall (x :: TK) (z :: TK).
AstArtifactRev x z
-> Concrete x
-> Maybe (Concrete (ADTensorKind z))
-> (Concrete (ADTensorKind x), Concrete z)
revInterpretArtifact AstArtifactRev
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
  (TKScalar r)
AstArtifactRev
  (TKProduct
     (XParams2 r Float) (TKProduct (TKR 1 Double) (TKR 1 Double)))
  (TKScalar r)
art Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
     (TKProduct (TKR 1 r) (TKR 1 r)))
parametersAndInput Maybe (Concrete (ADTensorKind (TKScalar r)))
Maybe (Concrete (TKScalar Double))
forall a. Maybe a
Nothing
          in [MnistDataLinearR r]
-> Concrete (XParams2 r Float) -> Concrete (XParams2 r Float)
go [MnistDataLinearR r]
rest (Double
-> SingletonTK
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete
     (ADTensorKind
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
              (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
forall (y :: TK).
Double
-> SingletonTK y
-> Concrete y
-> Concrete (ADTensorKind y)
-> Concrete y
updateWithGradient Double
gamma SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
Concrete (XParams2 r Float)
parameters Concrete
  (ADTensorKind
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
gradient)
        chunk :: [MnistDataLinearR r]
chunk = Int -> [MnistDataLinearR r] -> [MnistDataLinearR r]
forall a. Int -> [a] -> [a]
take Int
batchSize [MnistDataLinearR r]
xs
        gradf :: [MnistDataLinearR r] -> Concrete (XParams2 r Float)
gradf [MnistDataLinearR r]
c = [MnistDataLinearR r]
-> Concrete (XParams2 r Float) -> Concrete (XParams2 r Float)
go [MnistDataLinearR r]
c Concrete (XParams2 r Float)
targetInit
        name :: String
name =
          String
prefix
          String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords
               [ String
"v0 m" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show (SingletonTK (XParams2 r Float) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK (SingletonTK (XParams2 r Float) -> Int)
-> SingletonTK (XParams2 r Float) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams2 r Float))
               , String
"=" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show (SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
        (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
     (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
Concrete (XParams2 r Float)
targetInit) ]
    String -> Benchmarkable -> Benchmark
bench String
name (Benchmarkable -> Benchmark) -> Benchmarkable -> Benchmark
forall a b. (a -> b) -> a -> b
$ ([MnistDataLinearR r]
 -> Concrete
      (TKProduct
         (TKProduct
            (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
            (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
         (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))))
-> [MnistDataLinearR r] -> Benchmarkable
forall b a. NFData b => (a -> b) -> a -> Benchmarkable
nf [MnistDataLinearR r]
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double))
           (TKProduct (TKR2 2 (TKScalar Float)) (TKR 1 Double)))
        (TKProduct (TKR2 2 (TKScalar Double)) (TKR 1 Double)))
[MnistDataLinearR r] -> Concrete (XParams2 r Float)
gradf [MnistDataLinearR r]
chunk

-- | Benchmark group for the two-hidden-layer, rank-2 MNIST network whose
-- reverse-derivative artifact and initial parameters come from
-- 'MnistFcnnRanked2.mnistTrainBench2VTOGradientX', instantiated with
-- @Double@ as the main scalar type and @Float@ as the secondary one.
--
-- The artifact is built with 'IgnoreIncomingCotangent' and a fixed
-- generator seed (44), so the initial parameters are reproducible.
-- The arguments @1500@ and @500@ presumably are the hidden-layer widths,
-- matching the @"1500|500 "@ benchmark label (TODO confirm against
-- 'MnistFcnnRanked2.mnistTrainBench2VTOGradientX').
--
-- The samples are prepared in 'env':
--
-- * the MNIST test set is loaded;
-- * it is shuffled with a fixed seed (42);
-- * the first @chunkLength@ samples are kept and converted with
--   'mkMnistDataLinearR'.
mnistBGroup2VTOX :: Int -> Benchmark
mnistBGroup2VTOX chunkLength =
  let -- The nested bangs force both components together as soon as the
      -- pair is first demanded.
      (!targetInit, !art) =
        MnistFcnnRanked2.mnistTrainBench2VTOGradientX
          @Double (Proxy @Float) IgnoreIncomingCotangent
          1 (mkStdGen 44) 1500 500
      -- Loads, deterministically shuffles and truncates the test data.
      loadSamples :: IO [MnistDataLinearR Double]
      loadSamples = do
        testData0 <- loadMnistData testGlyphsPath testLabelsPath  -- 10k total
        let testData = shuffle (mkStdGen 42) testData0
        return $! map mkMnistDataLinearR $ take chunkLength testData
  in env loadSamples $ \xs ->
       bgroup ("2-hidden-layer rank 2 VTOX runtime MNIST nn with samples: "
               ++ show chunkLength)
         [ -- 0.02 is passed as the step-size argument of
           -- 'mnistTrainBench2VTOO' (presumably the learning rate).
           mnistTrainBench2VTOO "1500|500 " 0.02 chunkLength xs
                                (targetInit, art)
         ]

{- TODO: re-enable once -fpolymorphic-specialisation works

-- This is expected to fail with -O0 and to pass with -O1
-- and -fpolymorphic-specialisation.
-- This prevents running benchmarks without optimization, which is a good thing.
--
-- The `Storable` is only needed for overloaded profiling, e.g., with
-- cabal bench longMnistBench -ftest_seq -w /home/mikolaj/r/ghc.HEAD/ghc/_build/stage1/bin/ghc --allow-newer --enable-optimization --enable-profiling --profiling-detail=none --ghc-options="-fprof-late-overloaded -fpolymorphic-specialisation" --benchmark-options='-n1 -m pattern "1 VTA MNIST nn with samples: 400/500|150 v" +RTS -pj'
inspect $ hasNoTypeClassesExcept 'mnistTrainBench1VTA [''(~), ''KnownNat, ''WithDict, ''KnownShS, ''AdaptableTarget, ''RandomValue, ''KnownSTK, ''GoodScalar, ''Num, ''Show, ''Ord, ''Eq, ''Nested.PrimElt, ''Nested.KnownElt, ''Nested.NumElt, ''Typeable, ''IfDifferentiable, ''NFData, ''Default.Default, ''Nested.Storable]
inspect $ hasNoTypeClassesExcept 'mnistTrainBench1VTO [''(~), ''KnownNat, ''WithDict, ''KnownShS, ''AdaptableTarget, ''RandomValue, ''KnownSTK, ''GoodScalar, ''Num, ''Show, ''Ord, ''Eq, ''Nested.PrimElt, ''Nested.KnownElt, ''Nested.NumElt, ''Typeable, ''IfDifferentiable, ''NFData, ''Default.Default, ''Nested.Storable,      ''AstSpan, ''RealFloatH, ''Nested.FloatElt, ''Fractional, ''Floating, ''IntegralH, ''RealFrac, ''Real, ''Nested.Elt, ''Integral]
inspect $ hasNoTypeClassesExcept 'mnistTrainBench2VTA [''(~), ''KnownNat, ''WithDict, ''KnownShS, ''AdaptableTarget, ''RandomValue, ''KnownSTK, ''GoodScalar, ''Num, ''Show, ''Ord, ''Eq, ''Nested.PrimElt, ''Nested.KnownElt, ''Nested.NumElt, ''Typeable, ''IfDifferentiable, ''NFData, ''Default.Default, ''Nested.Storable]
inspect $ hasNoTypeClassesExcept 'mnistTrainBench2VTC [''(~), ''KnownNat, ''WithDict, ''KnownShS, ''AdaptableTarget, ''RandomValue, ''KnownSTK, ''GoodScalar, ''Num, ''Show, ''Ord, ''Eq, ''Nested.PrimElt, ''Nested.KnownElt, ''Nested.NumElt, ''Typeable, ''IfDifferentiable, ''NFData, ''Default.Default]
inspect $ hasNoTypeClassesExcept 'mnistTrainBench2VTO [''(~), ''GoodScalar, ''Show, ''Num, ''Ord, ''Eq, ''Nested.PrimElt, ''Nested.KnownElt, ''Nested.NumElt, ''Typeable, ''IfDifferentiable, ''NFData, ''Default.Default,       ''AstSpan, ''RealFloatH, ''Nested.FloatElt, ''Fractional, ''Floating, ''IntegralH, ''RealFrac, ''Real, ''Nested.Storable, ''WithDict, ''KnownShS, ''KnownSTK, ''KnownNat, ''Nested.Elt, ''Integral]
-- inspect $ coreOf 'mnistTrainBench1VTA
-}