{-# OPTIONS_GHC -Wno-missing-export-lists #-}
-- | Ranked-tensor implementation of a fully connected neural network
-- for classification of MNIST digits. It has two hidden layers and uses
-- no mini-batches, so the maximum rank of the tensors used is 2.
module MnistFcnnRanked2 where

import Prelude

import Data.Proxy (Proxy (Proxy))
import Data.Vector.Generic qualified as V
import GHC.Exts (inline)
import GHC.TypeLits (Nat)
import System.Random

import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Ranked.Shape

import HordeAd
import HordeAd.Core.Adaptor
import HordeAd.Core.CarriersAst
import MnistData

-- | Every trainable parameter of this network, with all dimension widths
-- checked statically (shaped tensors).
--
-- Components, in order: first hidden layer, second hidden layer, readout
-- layer; each one is a (weight matrix, bias vector) pair. Only the weight
-- matrix of the second hidden layer uses scalar type @q@; every other
-- tensor uses @r@.
type ADFcnnMnist2ParametersShaped
       (target :: Target) (widthHidden :: Nat) (widthHidden2 :: Nat) r q =
  ( -- first hidden layer: glyph -> widthHidden
    (target (TKS '[widthHidden, SizeMnistGlyph] r), target (TKS '[widthHidden] r))
    -- second hidden layer: widthHidden -> widthHidden2
  , (target (TKS '[widthHidden2, widthHidden] q), target (TKS '[widthHidden2] r))
    -- readout layer: widthHidden2 -> label
  , ( target (TKS '[SizeMnistLabel, widthHidden2] r)
    , target (TKS '[SizeMnistLabel] r) )
  )

-- | Every trainable parameter of this network as ranked tensors: only the
-- ranks (2 for weight matrices, 1 for bias vectors) appear in the type,
-- not the widths. Layer order and scalar types follow
-- 'ADFcnnMnist2ParametersShaped': first hidden layer, second hidden layer
-- (whose weight matrix alone uses @q@), then the readout layer.
type ADFcnnMnist2Parameters (target :: Target) r q =
  ( (target (TKR 2 r), target (TKR 1 r))  -- first hidden layer: weights, bias
  , (target (TKR 2 q), target (TKR 1 r))  -- second hidden layer: weights, bias
  , (target (TKR 2 r), target (TKR 1 r))  -- readout layer: weights, bias
  )

-- | The tensor kind ('X') of the whole parameter tuple
-- 'ADFcnnMnist2Parameters', instantiated at the 'Concrete' target.
type XParams2 r q =
  X (MnistFcnnRanked2.ADFcnnMnist2Parameters Concrete r q)

-- | Fully connected neural network for the MNIST digit classification task.
-- There are two hidden layers and both use the same activation function,
-- @factivationHidden@; the output layer uses @factivationOutput@.
-- Being ranked, the widths of the two hidden layers are not part of the
-- type and are determined by the shapes of the parameter tensors.
--
-- The second hidden layer's weight matrix has scalar type @q@, so the
-- layer's input is cast from @r@ to @q@ before the matrix-vector product
-- and the product is cast back to @r@ before its bias is added.
afcnnMnist2 :: ( ADReady target, GoodScalar r, Differentiable r
               , GoodScalar q, Differentiable q )
            => (target (TKR 1 r) -> target (TKR 1 r))
               -- ^ activation function of both hidden layers
            -> (target (TKR 1 r) -> target (TKR 1 r))
               -- ^ activation function of the output layer
            -> target (TKR 1 r)
               -- ^ input datum (a single sample, no mini-batch)
            -> ADFcnnMnist2Parameters target r q
               -- ^ trainable parameters
            -> target (TKR 1 r)
afcnnMnist2 factivationHidden factivationOutput
            datum ((hidden, bias), (hidden2, bias2), (readout, biasr)) =
  let hiddenLayer1 = rmatvecmul hidden datum + bias
      nonlinearLayer1 = factivationHidden hiddenLayer1
      -- Layer 2 weights are in precision @q@: cast in, multiply, cast back.
      hiddenLayer2 = rcast (rmatvecmul hidden2 (rcast nonlinearLayer1)) + bias2
      nonlinearLayer2 = factivationHidden hiddenLayer2
      outputLayer = rmatvecmul readout nonlinearLayer2 + biasr
  in factivationOutput outputLayer

-- | The neural network applied to concrete activation functions
-- and composed with the appropriate loss function: the hidden layers use
-- 'logistic', the output layer uses 'softMax1', and the network output is
-- scored against the expected label vector with 'lossCrossEntropyV'.
afcnnMnistLoss2
  :: ( ADReady target, GoodScalar r, Differentiable r
     , GoodScalar q, Differentiable q )
  => (target (TKR 1 r), target (TKR 1 r))
       -- ^ (input glyph, expected label vector)
  -> ADFcnnMnist2Parameters target r q
       -- ^ trainable parameters
  -> target (TKScalar r)
afcnnMnistLoss2 (datum, label) adparams =
  -- 'inline' (from GHC.Exts) asks GHC to inline 'afcnnMnist2' at this site.
  let result = inline afcnnMnist2 logistic softMax1 datum adparams
  in lossCrossEntropyV label result
{-# SPECIALIZE afcnnMnistLoss2 :: (ADVal Concrete (TKR 1 Double), ADVal Concrete (TKR 1 Double)) -> ADFcnnMnist2Parameters (ADVal Concrete) Double Float -> ADVal Concrete (TKScalar Double) #-}
{-# SPECIALIZE afcnnMnistLoss2 :: (ADVal Concrete (TKR 1 Float), ADVal Concrete (TKR 1 Float)) -> ADFcnnMnist2Parameters (ADVal Concrete) Float Float -> ADVal Concrete (TKScalar Float) #-}
{-# SPECIALIZE afcnnMnistLoss2 :: (ADVal Concrete (TKR 1 Double), ADVal Concrete (TKR 1 Double)) -> ADFcnnMnist2Parameters (ADVal Concrete) Double Double -> ADVal Concrete (TKScalar Double) #-}

-- | A function testing the neural network given testing set of inputs
-- and the trained parameters.
--
-- Returns the fraction of samples classified correctly. A sample
-- @(glyph, label)@ counts as correct when the index of the largest network
-- output (hidden activation 'logistic', output activation 'softMax1')
-- equals the index of the largest entry of @label@. An empty test set
-- yields 0.
afcnnMnistTest2
  :: forall target r q.
     ( target ~ Concrete, GoodScalar r, Differentiable r
     , GoodScalar q, Differentiable q )
  => [MnistDataLinearR r]
       -- ^ test samples, each a (glyph, label) pair
  -> ADFcnnMnist2Parameters target r q
       -- ^ trained parameters
  -> r
afcnnMnistTest2 [] _ = 0
afcnnMnistTest2 dataList testParams =
  let matchesLabels :: MnistDataLinearR r -> Bool
      matchesLabels (glyph, label) =
        let glyph1 = rconcrete glyph
            nn :: ADFcnnMnist2Parameters target r q
               -> target (TKR 1 r)
            nn = inline afcnnMnist2 logistic softMax1 glyph1
            v = rtoVector $ nn testParams
        in V.maxIndex v == V.maxIndex (Nested.rtoVector label)
  in fromIntegral (length (filter matchesLabels dataList))
     / fromIntegral (length dataList)

-- | The loss function applied to randomly generated initial parameters
-- and wrapped in artifact generation. This is helpful to share code
-- between tests and benchmarks and to separate compile-time and run-time
-- for benchmarking (this part is considered compile-time).
--
-- The two hidden-layer widths arrive as runtime 'Int's and are promoted to
-- the type level with 'withSNat'. That lets the initial values be drawn by
-- 'randomValue' at the shaped type 'ADFcnnMnist2ParametersShaped'. The
-- shapes are then dropped with 'forgetShape' to obtain the ranked
-- parameters. @range@ and @seed@ are passed straight to 'randomValue'
-- (presumably the spread of the initial values and the RNG seed).
--
-- The returned artifact is built by 'revArtifactAdapt' from
-- 'afcnnMnistLoss2'. Its input is the product of the parameters with one
-- (glyph, label) sample of lengths 'sizeMnistGlyphInt' and
-- 'sizeMnistLabelInt'.
mnistTrainBench2VTOGradient
  :: forall r q. ( GoodScalar r, Differentiable r
                 , GoodScalar q, Differentiable q )
  => Proxy q -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int
  -> ( Concrete (XParams2 r q)
     , AstArtifactRev
         (TKProduct
            (XParams2 r q)
            (TKProduct (TKR2 1 (TKScalar r))
                       (TKR2 1 (TKScalar r))))
         (TKScalar r) )
mnistTrainBench2VTOGradient Proxy cotangentHandling range seed
                            widthHidden widthHidden2 =
  withSNat widthHidden $ \(SNat @widthHidden) ->
  withSNat widthHidden2 $ \(SNat @widthHidden2) ->
  -- Initial parameter generation is counted as part of compilation time.
  let targetInit =
        forgetShape $ fst
        $ randomValue @(Concrete (X (MnistFcnnRanked2.ADFcnnMnist2ParametersShaped
                                       Concrete widthHidden widthHidden2 r q)))
                      range seed
      ftk = tftk @Concrete (knownSTK @(XParams2 r q)) targetInit
      -- Shape of a single (glyph, label) sample.
      ftkData = FTKProduct (FTKR (sizeMnistGlyphInt :$: ZSR) FTKScalar)
                           (FTKR (sizeMnistLabelInt :$: ZSR) FTKScalar)
      f :: ( MnistFcnnRanked2.ADFcnnMnist2Parameters
               (AstTensor AstMethodLet FullSpan) r q
           , ( AstTensor AstMethodLet FullSpan (TKR 1 r)
             , AstTensor AstMethodLet FullSpan (TKR 1 r) ) )
        -> AstTensor AstMethodLet FullSpan (TKScalar r)
      f (pars, (glyphR, labelR)) =
        afcnnMnistLoss2 (glyphR, labelR) pars
      artRaw = revArtifactAdapt cotangentHandling f (FTKProduct ftk ftkData)
  in (targetInit, artRaw)
{-# SPECIALIZE mnistTrainBench2VTOGradient :: Proxy Float -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int -> ( Concrete (XParams2 Double Float), AstArtifactRev (TKProduct (XParams2 Double Float) (TKProduct (TKR2 1 (TKScalar Double)) (TKR2 1 (TKScalar Double)))) (TKScalar Double) ) #-}
{-# SPECIALIZE mnistTrainBench2VTOGradient :: Proxy Float -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int -> ( Concrete (XParams2 Float Float), AstArtifactRev (TKProduct (XParams2 Float Float) (TKProduct (TKR2 1 (TKScalar Float)) (TKR2 1 (TKScalar Float)))) (TKScalar Float) ) #-}
{-# SPECIALIZE mnistTrainBench2VTOGradient :: Proxy Double -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int -> ( Concrete (XParams2 Double Double), AstArtifactRev (TKProduct (XParams2 Double Double) (TKProduct (TKR2 1 (TKScalar Double)) (TKR2 1 (TKScalar Double)))) (TKScalar Double) ) #-}

-- | A version of 'mnistTrainBench2VTOGradient' without any simplification,
-- even the AST smart constructors. Intended for benchmarking.
mnistTrainBench2VTOGradientX
  :: forall r q. ( GoodScalar r, Differentiable r
                 , GoodScalar q, Differentiable q )
  => Proxy q -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int
  -> ( Concrete (XParams2 r q)
     , AstArtifactRev
         (TKProduct
            (XParams2 r q)
            (TKProduct (TKR2 1 (TKScalar r))
                       (TKR2 1 (TKScalar r))))
         (TKScalar r) )
mnistTrainBench2VTOGradientX :: forall r q.
(GoodScalar r, Differentiable r, GoodScalar q, Differentiable q) =>
Proxy @Type q
-> IncomingCotangentHandling
-> Double
-> StdGen
-> Int
-> Int
-> (Concrete (XParams2 r q),
    AstArtifactRev
      (TKProduct
         (XParams2 r q)
         (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKScalar r))
mnistTrainBench2VTOGradientX Proxy @Type q
Proxy IncomingCotangentHandling
cotangentHandling Double
range StdGen
seed Int
widthHidden Int
widthHidden2 =
  Int
-> (forall (n :: Nat).
    KnownNat n =>
    SNat n
    -> (Concrete (XParams2 r q),
        AstArtifactRev
          (TKProduct
             (XParams2 r q)
             (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
          (TKScalar r)))
-> (Concrete (XParams2 r q),
    AstArtifactRev
      (TKProduct
         (XParams2 r q)
         (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKScalar r))
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHidden ((forall (n :: Nat).
  KnownNat n =>
  SNat n
  -> (Concrete (XParams2 r q),
      AstArtifactRev
        (TKProduct
           (XParams2 r q)
           (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKScalar r)))
 -> (Concrete (XParams2 r q),
     AstArtifactRev
       (TKProduct
          (XParams2 r q)
          (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
       (TKScalar r)))
-> (forall (n :: Nat).
    KnownNat n =>
    SNat n
    -> (Concrete (XParams2 r q),
        AstArtifactRev
          (TKProduct
             (XParams2 r q)
             (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
          (TKScalar r)))
-> (Concrete (XParams2 r q),
    AstArtifactRev
      (TKProduct
         (XParams2 r q)
         (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKScalar r))
forall a b. (a -> b) -> a -> b
$ \(SNat @widthHidden) ->
  Int
-> (forall (n :: Nat).
    KnownNat n =>
    SNat n
    -> (Concrete (XParams2 r q),
        AstArtifactRev
          (TKProduct
             (XParams2 r q)
             (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
          (TKScalar r)))
-> (Concrete (XParams2 r q),
    AstArtifactRev
      (TKProduct
         (XParams2 r q)
         (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKScalar r))
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHidden2 ((forall (n :: Nat).
  KnownNat n =>
  SNat n
  -> (Concrete (XParams2 r q),
      AstArtifactRev
        (TKProduct
           (XParams2 r q)
           (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKScalar r)))
 -> (Concrete (XParams2 r q),
     AstArtifactRev
       (TKProduct
          (XParams2 r q)
          (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
       (TKScalar r)))
-> (forall (n :: Nat).
    KnownNat n =>
    SNat n
    -> (Concrete (XParams2 r q),
        AstArtifactRev
          (TKProduct
             (XParams2 r q)
             (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
          (TKScalar r)))
-> (Concrete (XParams2 r q),
    AstArtifactRev
      (TKProduct
         (XParams2 r q)
         (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
      (TKScalar r))
forall a b. (a -> b) -> a -> b
$ \(SNat @widthHidden2) ->
  -- Initial parameter generation is counted as part of compilation time.
  let targetInit :: NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit =
        Concrete
  (TKProduct
     (TKProduct
        (TKProduct
           (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
        (TKProduct
           (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
           (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
     (TKProduct
        (TKS2
           ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
        (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape (Concrete
   (TKProduct
      (TKProduct
         (TKProduct
            (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
         (TKProduct
            (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
            (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
      (TKProduct
         (TKS2
            ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
         (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
 -> NoShape
      (Concrete
         (TKProduct
            (TKProduct
               (TKProduct
                  (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
               (TKProduct
                  (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
                  (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
            (TKProduct
               (TKS2
                  ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
               (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
-> Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
     (Concrete
        (TKProduct
           (TKProduct
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
              (TKProduct
                 (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
                 (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
           (TKProduct
              (TKS2
                 ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a b. (a -> b) -> a -> b
$ (Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r q)),
 StdGen)
-> Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r q))
forall a b. (a, b) -> a
fst
        ((Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r q)),
  StdGen)
 -> Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r q)))
-> (Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r q)),
    StdGen)
-> Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r q))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue @(Concrete (X (MnistFcnnRanked2.ADFcnnMnist2ParametersShaped
                                   Concrete widthHidden widthHidden2 r q)))
                      Double
range StdGen
seed
      ftk :: FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar q)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
ftk = forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams2 r q)) Concrete
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar q)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit
      ftkData :: FullShapeTK
  (TKProduct (TKR2 (0 + 1) (TKScalar r)) (TKR2 (0 + 1) (TKScalar r)))
ftkData = FullShapeTK (TKR2 (0 + 1) (TKScalar r))
-> FullShapeTK (TKR2 (0 + 1) (TKScalar r))
-> FullShapeTK
     (TKProduct (TKR2 (0 + 1) (TKScalar r)) (TKR2 (0 + 1) (TKScalar r)))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct (IShR (0 + 1)
-> FullShapeTK (TKScalar r)
-> FullShapeTK (TKR2 (0 + 1) (TKScalar r))
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR (Int
sizeMnistGlyphInt Int -> ShR 0 Int -> IShR (0 + 1)
forall {n1 :: Nat} {i} (n :: Nat).
((n + 1 :: Nat) ~ (n1 :: Nat)) =>
i -> ShR n i -> ShR n1 i
:$: ShR 0 Int
forall (n :: Nat) i. ((n :: Nat) ~ (0 :: Nat)) => ShR n i
ZSR) FullShapeTK (TKScalar r)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar)
                           (IShR (0 + 1)
-> FullShapeTK (TKScalar r)
-> FullShapeTK (TKR2 (0 + 1) (TKScalar r))
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR (Int
sizeMnistLabelInt Int -> ShR 0 Int -> IShR (0 + 1)
forall {n1 :: Nat} {i} (n :: Nat).
((n + 1 :: Nat) ~ (n1 :: Nat)) =>
i -> ShR n i -> ShR n1 i
:$: ShR 0 Int
forall (n :: Nat) i. ((n :: Nat) ~ (0 :: Nat)) => ShR n i
ZSR) FullShapeTK (TKScalar r)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar)
      f :: ( MnistFcnnRanked2.ADFcnnMnist2Parameters
               (AstTensor AstMethodLet FullSpan) r q
           , ( AstTensor AstMethodLet FullSpan (TKR 1 r)
             , AstTensor AstMethodLet FullSpan (TKR 1 r) ) )
        -> AstTensor AstMethodLet FullSpan (TKScalar r)
      f :: (ADFcnnMnist2Parameters (AstTensor AstMethodLet FullSpan) r q,
 (AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)),
  AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f (((AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
hidden, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
bias), (AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar q))
hidden2, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
bias2), (AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
readout, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
biasr)), (AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
glyphR, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
labelR)) =
        AstNoSimplify FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (s :: AstSpanType) (y :: TK).
AstNoSimplify s y -> AstTensor AstMethodLet s y
unAstNoSimplify
        (AstNoSimplify FullSpan (TKScalar r)
 -> AstTensor AstMethodLet FullSpan (TKScalar r))
-> AstNoSimplify FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall a b. (a -> b) -> a -> b
$ (AstNoSimplify FullSpan (TKR2 1 (TKScalar r)),
 AstNoSimplify FullSpan (TKR2 1 (TKScalar r)))
-> ADFcnnMnist2Parameters (AstNoSimplify FullSpan) r q
-> AstNoSimplify FullSpan (TKScalar r)
forall (target :: Target) r q.
(ADReady target, GoodScalar r, Differentiable r, GoodScalar q,
 Differentiable q) =>
(target (TKR 1 r), target (TKR 1 r))
-> ADFcnnMnist2Parameters target r q -> target (TKScalar r)
afcnnMnistLoss2 (AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
-> AstNoSimplify FullSpan (TKR2 1 (TKScalar r))
forall (s :: AstSpanType) (y :: TK).
AstTensor AstMethodLet s y -> AstNoSimplify s y
AstNoSimplify AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
glyphR, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
-> AstNoSimplify FullSpan (TKR2 1 (TKScalar r))
forall (s :: AstSpanType) (y :: TK).
AstTensor AstMethodLet s y -> AstNoSimplify s y
AstNoSimplify AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
labelR)
                          ( (AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
-> AstNoSimplify FullSpan (TKR2 2 (TKScalar r))
forall (s :: AstSpanType) (y :: TK).
AstTensor AstMethodLet s y -> AstNoSimplify s y
AstNoSimplify AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
hidden, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
-> AstNoSimplify FullSpan (TKR2 1 (TKScalar r))
forall (s :: AstSpanType) (y :: TK).
AstTensor AstMethodLet s y -> AstNoSimplify s y
AstNoSimplify AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
bias)
                          , (AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar q))
-> AstNoSimplify FullSpan (TKR2 2 (TKScalar q))
forall (s :: AstSpanType) (y :: TK).
AstTensor AstMethodLet s y -> AstNoSimplify s y
AstNoSimplify AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar q))
hidden2, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
-> AstNoSimplify FullSpan (TKR2 1 (TKScalar r))
forall (s :: AstSpanType) (y :: TK).
AstTensor AstMethodLet s y -> AstNoSimplify s y
AstNoSimplify AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
bias2)
                          , (AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
-> AstNoSimplify FullSpan (TKR2 2 (TKScalar r))
forall (s :: AstSpanType) (y :: TK).
AstTensor AstMethodLet s y -> AstNoSimplify s y
AstNoSimplify AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
readout, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
-> AstNoSimplify FullSpan (TKR2 1 (TKScalar r))
forall (s :: AstSpanType) (y :: TK).
AstTensor AstMethodLet s y -> AstNoSimplify s y
AstNoSimplify AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
biasr) )
      artRaw :: AstArtifactRev
  (X (ADFcnnMnist2Parameters (AstTensor AstMethodLet FullSpan) r q,
      (AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)),
       AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)))))
  (TKScalar r)
artRaw = IncomingCotangentHandling
-> ((ADFcnnMnist2Parameters (AstTensor AstMethodLet FullSpan) r q,
     (AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)),
      AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))))
    -> AstTensor AstMethodLet FullSpan (TKScalar r))
-> FullShapeTK
     (X (ADFcnnMnist2Parameters (AstTensor AstMethodLet FullSpan) r q,
         (AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)),
          AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)))))
-> AstArtifactRev
     (X (ADFcnnMnist2Parameters (AstTensor AstMethodLet FullSpan) r q,
         (AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)),
          AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)))))
     (TKScalar r)
forall src (ztgt :: TK) tgt.
(AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
 (tgt :: Type) ~ (AstTensor AstMethodLet FullSpan ztgt :: Type)) =>
IncomingCotangentHandling
-> (src -> tgt)
-> FullShapeTK (X src)
-> AstArtifactRev (X src) ztgt
revArtifactAdapt IncomingCotangentHandling
cotangentHandling (ADFcnnMnist2Parameters (AstTensor AstMethodLet FullSpan) r q,
 (AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)),
  AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f (FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar q)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> FullShapeTK
     (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r)))
-> FullShapeTK
     (TKProduct
        (TKProduct
           (TKProduct
              (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
              (TKProduct (TKR2 2 (TKScalar q)) (TKR2 1 (TKScalar r))))
           (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
        (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK
  (TKProduct
     (TKProduct
        (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
        (TKProduct (TKR2 2 (TKScalar q)) (TKR2 1 (TKScalar r))))
     (TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
ftk FullShapeTK (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r)))
ftkData)
  in (Concrete (XParams2 r q)
NoShape
  (Concrete
     (TKProduct
        (TKProduct
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
           (TKProduct
              (TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
              (TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
        (TKProduct
           (TKS2
              ((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
           (TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit, AstArtifactRev
  (TKProduct
     (XParams2 r q)
     (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
  (TKScalar r)
AstArtifactRev
  (X (ADFcnnMnist2Parameters (AstTensor AstMethodLet FullSpan) r q,
      (AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)),
       AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)))))
  (TKScalar r)
artRaw)
-- Monomorphic specializations of 'mnistTrainBench2VTOGradientX' for three
-- concrete (r, q) scalar combinations: (Double, Float), (Float, Float) and
-- (Double, Double). In each one the 'Proxy' argument's type is q (the scalar
-- of the second hidden layer's weight matrix in 'XParams2 r q'). The other
-- arguments are IncomingCotangentHandling -> Double -> StdGen -> Int -> Int.
-- NOTE(review): these pragma types must stay in sync with the function's
-- signature; if that signature changes, update all three lines.
{-# SPECIALIZE mnistTrainBench2VTOGradientX :: Proxy Float -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int -> ( Concrete (XParams2 Double Float), AstArtifactRev (TKProduct (XParams2 Double Float) (TKProduct (TKR2 1 (TKScalar Double)) (TKR2 1 (TKScalar Double)))) (TKScalar Double) ) #-}
{-# SPECIALIZE mnistTrainBench2VTOGradientX :: Proxy Float -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int -> ( Concrete (XParams2 Float Float), AstArtifactRev (TKProduct (XParams2 Float Float) (TKProduct (TKR2 1 (TKScalar Float)) (TKR2 1 (TKScalar Float)))) (TKScalar Float) ) #-}
{-# SPECIALIZE mnistTrainBench2VTOGradientX :: Proxy Double -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int -> ( Concrete (XParams2 Double Double), AstArtifactRev (TKProduct (XParams2 Double Double) (TKProduct (TKR2 1 (TKScalar Double)) (TKR2 1 (TKScalar Double)))) (TKScalar Double) ) #-}