{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_GHC -Wno-missing-export-lists #-}
module BenchMnistTools where
import Prelude
import Control.Arrow ((***))
import Control.DeepSeq (NFData (..))
import Criterion.Main
import Data.Default qualified as Default
import Data.Proxy (Proxy (Proxy))
import GHC.Exts (WithDict)
import GHC.TypeLits (KnownNat)
import System.Random
import Test.Inspection
import Type.Reflection (Typeable)
import HordeAd
import HordeAd.Core.Adaptor
import HordeAd.Core.OpsConcrete ()
import HordeAd.External.OptimizerTools
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Shaped.Shape
import MnistData
import MnistFcnnRanked1 qualified
import MnistFcnnRanked2 (XParams2)
import MnistFcnnRanked2 qualified
-- | The tensor kind that 'X' assigns to the parameter tuple
-- 'MnistFcnnRanked1.ADFcnnMnist1Parameters' instantiated at the 'Concrete'
-- target, for the given @widthHidden@, @widthHidden2@ and scalar type @r@.
--
-- The continuation lines are indented so that the layout rule keeps them
-- inside this single declaration.
type XParams widthHidden widthHidden2 r =
  X (MnistFcnnRanked1.ADFcnnMnist1Parameters
       Concrete widthHidden widthHidden2 r)
-- | Build a criterion 'Benchmark' that runs 'sgdSTK' with the loss
-- 'MnistFcnnRanked1.afcnnMnistLoss1' over the first @batchSize@ elements
-- of the supplied data.
--
-- The two @Int@ widths are promoted to type-level naturals with 'withSNat'.
-- The initial parameters come from @randomValue 1 (mkStdGen 44)@, i.e. from
-- a fixed generator seed, so every run starts from the same values.
--
-- The benchmark name is @prefix@ followed by two space-separated fields:
-- @"v"@ with the 'widthSTK' of the parameter kind, and @"m0 ="@ with the
-- 'tsize' of the initial parameters.
--
-- The measured function is forced with 'nf' and returns the first component
-- of the pair produced by 'sgdSTK' (presumably the trained parameters;
-- NOTE(review): confirm against the 'sgdSTK' documentation).
mnistTrainBench1VTA
  :: forall r. r ~ Double
  => String                -- ^ prefix of the benchmark name
  -> Int                   -- ^ @widthHidden@, promoted with 'withSNat'
  -> Int                   -- ^ @widthHidden2@, promoted with 'withSNat'
  -> Double                -- ^ @gamma@, passed unchanged to 'sgdSTK'
  -> Int                   -- ^ how many leading elements of the data to use
  -> [MnistDataLinearR r]  -- ^ the data the benchmark samples from
  -> Benchmark
mnistTrainBench1VTA prefix widthHiddenInt widthHidden2Int
                    gamma batchSize xs =
  withSNat widthHiddenInt $ \(widthHiddenSNat :: SNat widthHidden) ->
  withSNat widthHidden2Int $ \(widthHidden2SNat :: SNat widthHidden2) ->
  -- The widths are only known at run time, so the 'KnownSTK' dictionaries
  -- for the two list-shaped parameter components are built from singletons
  -- and brought into scope explicitly; the uses of 'knownSTK' below at the
  -- full parameter kind rely on them.
  withKnownSTK
    (stkOfListR (knownSTK @(TKS '[SizeMnistGlyph] r)) (SNat @widthHidden)) $
  withKnownSTK
    (stkOfListR (knownSTK @(TKS '[widthHidden] Float)) (SNat @widthHidden2)) $
  let valsInit :: MnistFcnnRanked1.ADFcnnMnist1Parameters
                    Concrete widthHidden widthHidden2 r
      valsInit = fst $ randomValue 1 (mkStdGen 44)
      targetInit :: Concrete (XParams widthHidden widthHidden2 r)
      targetInit = toTarget @Concrete valsInit
      -- Loss of a single (glyph, label) sample at the given parameters.
      f :: MnistDataLinearR Double
        -> ADVal Concrete (XParams widthHidden widthHidden2 Double)
        -> ADVal Concrete (TKScalar Double)
      f (glyph, label) adinputs =
        MnistFcnnRanked1.afcnnMnistLoss1
          widthHiddenSNat widthHidden2SNat
          (rconcrete glyph, rconcrete label) (fromTarget adinputs)
      chunk :: [MnistDataLinearR r]
      chunk = take batchSize xs
      -- The function under measurement: one 'sgdSTK' run over the given
      -- samples, always starting from 'targetInit'.
      gradf c = fst $ sgdSTK knownSTK gamma f c targetInit
      name :: String
      name =
        prefix
        ++ unwords
             [ "v" ++ show (widthSTK
                            $ knownSTK @(XParams widthHidden widthHidden2 r))
             , "m0" ++ " =" ++ show (tsize knownSTK targetInit) ]
  in bench name $ nf gradf chunk
mnistTestBench1VTA
:: forall r. r ~ Double
=> String
-> Int -> Int -> Double -> Int -> [MnistDataLinearR r]
-> Benchmark
mnistTestBench1VTA :: forall r.
((r :: Type) ~ (Double :: Type)) =>
String
-> Int -> Int -> Double -> Int -> [MnistDataLinearR r] -> Benchmark
mnistTestBench1VTA String
prefix Int
widthHiddenInt Int
widthHidden2Int
Double
_gamma Int
batchSize [MnistDataLinearR r]
xs =
Int
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHiddenInt ((forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark)
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$ \(SNat n
widthHiddenSNat :: SNat widthHidden) ->
Int
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHidden2Int ((forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark)
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$ \(SNat n
widthHidden2SNat :: SNat widthHidden2) ->
SingletonTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
-> (KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)) =>
Benchmark)
-> Benchmark
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK
(SingletonTK (TKS ((':) @Nat 784 ('[] @Nat)) Double)
-> SNat n
-> SingletonTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
forall (t :: TK) (n :: Nat).
SingletonTK t -> SNat n -> SingletonTK (Tups n t)
stkOfListR (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(TKS '[SizeMnistGlyph] r)) (forall (n :: Nat). KnownNat n => SNat n
SNat @widthHidden)) ((KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)) =>
Benchmark)
-> Benchmark)
-> (KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)) =>
Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$
SingletonTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
-> (KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
Benchmark)
-> Benchmark
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK
(SingletonTK (TKS ((':) @Nat n ('[] @Nat)) Float)
-> SNat n
-> SingletonTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
forall (t :: TK) (n :: Nat).
SingletonTK t -> SNat n -> SingletonTK (Tups n t)
stkOfListR (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(TKS '[widthHidden] Float)) (forall (n :: Nat). KnownNat n => SNat n
SNat @widthHidden2)) ((KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
Benchmark)
-> Benchmark)
-> (KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$
let valsInit :: MnistFcnnRanked1.ADFcnnMnist1Parameters
Concrete widthHidden widthHidden2 r
valsInit :: ADFcnnMnist1Parameters Concrete n n r
valsInit = (ADFcnnMnist1Parameters Concrete n n r, StdGen)
-> ADFcnnMnist1Parameters Concrete n n r
forall a b. (a, b) -> a
fst ((ADFcnnMnist1Parameters Concrete n n r, StdGen)
-> ADFcnnMnist1Parameters Concrete n n r)
-> (ADFcnnMnist1Parameters Concrete n n r, StdGen)
-> ADFcnnMnist1Parameters Concrete n n r
forall a b. (a -> b) -> a -> b
$ Double
-> StdGen
-> (((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR
SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
StdGen)
forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue Double
1 (Int -> StdGen
mkStdGen Int
44)
targetInit :: Concrete (XParams widthHidden widthHidden2 r)
targetInit :: Concrete (XParams n n r)
targetInit = forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR
SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
ADFcnnMnist1Parameters Concrete n n r
valsInit
ftest :: [MnistDataLinearR r]
-> MnistFcnnRanked1.ADFcnnMnist1Parameters
Concrete widthHidden widthHidden2 r
-> r
ftest :: [MnistDataLinearR r] -> ADFcnnMnist1Parameters Concrete n n r -> r
ftest = SNat n
-> SNat n
-> [MnistDataLinearR r]
-> ADFcnnMnist1Parameters Concrete n n r
-> r
forall (target :: Target) (widthHidden :: Nat)
(widthHidden2 :: Nat) r.
((target :: Target) ~ (Concrete :: Target), GoodScalar r,
Differentiable r) =>
SNat widthHidden
-> SNat widthHidden2
-> [MnistDataLinearR r]
-> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
-> r
MnistFcnnRanked1.afcnnMnistTest1 SNat n
widthHiddenSNat SNat n
widthHidden2SNat
in do
let chunk :: [MnistDataLinearR r]
chunk = Int -> [MnistDataLinearR r] -> [MnistDataLinearR r]
forall a. Int -> [a] -> [a]
take Int
batchSize [MnistDataLinearR r]
xs
score :: [MnistDataLinearR r] -> r
score [MnistDataLinearR r]
c = [MnistDataLinearR r] -> ADFcnnMnist1Parameters Concrete n n r -> r
ftest [MnistDataLinearR r]
c ADFcnnMnist1Parameters Concrete n n r
valsInit
name :: String
name =
String
"test " String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
prefix
String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords
[ String
"v" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show (SingletonTK (XParams n n r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
(SingletonTK (XParams n n r) -> Int)
-> SingletonTK (XParams n n r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams widthHidden widthHidden2 r))
, String
"m0" String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
" =" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show (SingletonTK
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
Concrete (XParams n n r)
targetInit) ]
String -> Benchmarkable -> Benchmark
bench String
name (Benchmarkable -> Benchmark) -> Benchmarkable -> Benchmark
forall a b. (a -> b) -> a -> b
$ ([MnistDataLinearR r] -> r)
-> [MnistDataLinearR r] -> Benchmarkable
forall a b. (a -> b) -> a -> Benchmarkable
whnf [MnistDataLinearR r] -> r
score [MnistDataLinearR r]
chunk
-- | Benchmark group for the 2-hidden-layer rank 1 VTA MNIST network.
--
-- The argument @chunkLength@ is used in three places: it is the number of
-- samples kept from the shuffled test set, it is passed to every benchmark
-- as its batch size, and it is shown in the group label.
--
-- The data is loaded once via 'env': the test glyphs and labels are read
-- with 'loadMnistData', shuffled with the fixed seed @mkStdGen 42@ so that
-- runs are reproducible, truncated to @chunkLength@ samples and converted
-- with 'mkMnistDataLinearR'.
--
-- Each benchmark receives a name prefix, the two hidden-layer widths,
-- gamma (always @0.02@ here), the batch size and the samples.  Test and
-- train benchmarks are paired for the widths 30|10, 300|100 and 500|150;
-- the 1500|500 configuration has a train benchmark only.
mnistBGroup1VTA :: Int -> Benchmark
mnistBGroup1VTA chunkLength =
  env loadChunk $ \xs ->
    bgroup ("2-hidden-layer rank 1 VTA MNIST nn with samples: "
            ++ show chunkLength)
      [ mnistTestBench1VTA "30|10 " 30 10 0.02 chunkLength xs
      , mnistTrainBench1VTA "30|10 " 30 10 0.02 chunkLength xs
      , mnistTestBench1VTA "300|100 " 300 100 0.02 chunkLength xs
      , mnistTrainBench1VTA "300|100 " 300 100 0.02 chunkLength xs
      , mnistTestBench1VTA "500|150 " 500 150 0.02 chunkLength xs
      , mnistTrainBench1VTA "500|150 " 500 150 0.02 chunkLength xs
      , mnistTrainBench1VTA "1500|500 " 1500 500 0.02 chunkLength xs
      ]
 where
  -- Environment action run by 'env' before the benchmarks of the group.
  loadChunk :: IO [MnistDataLinearR Double]
  loadChunk = do
    testData0 <- loadMnistData testGlyphsPath testLabelsPath
    let testData = shuffle (mkStdGen 42) testData0
    return $! map mkMnistDataLinearR $ take chunkLength testData
mnistTrainBench1VTO
:: forall r. r ~ Double
=> String
-> Int -> Int -> Double -> Int -> [MnistDataLinearR r]
-> Benchmark
mnistTrainBench1VTO :: forall r.
((r :: Type) ~ (Double :: Type)) =>
String
-> Int -> Int -> Double -> Int -> [MnistDataLinearR r] -> Benchmark
mnistTrainBench1VTO String
prefix Int
widthHiddenInt Int
widthHidden2Int
Double
gamma Int
batchSize [MnistDataLinearR r]
xs =
Int
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHiddenInt ((forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark)
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$ \(SNat n
widthHiddenSNat :: SNat widthHidden) ->
Int
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHidden2Int ((forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark)
-> (forall (n :: Nat). KnownNat n => SNat n -> Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$ \(SNat n
widthHidden2SNat :: SNat widthHidden2) ->
SingletonTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
-> (KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)) =>
Benchmark)
-> Benchmark
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK
(SingletonTK (TKS ((':) @Nat 784 ('[] @Nat)) Double)
-> SNat n
-> SingletonTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
forall (t :: TK) (n :: Nat).
SingletonTK t -> SNat n -> SingletonTK (Tups n t)
stkOfListR (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(TKS '[SizeMnistGlyph] r)) (forall (n :: Nat). KnownNat n => SNat n
SNat @widthHidden)) ((KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)) =>
Benchmark)
-> Benchmark)
-> (KnownSTK (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)) =>
Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$
SingletonTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
-> (KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
Benchmark)
-> Benchmark
forall (y :: TK) r. SingletonTK y -> (KnownSTK y => r) -> r
withKnownSTK
(SingletonTK (TKS ((':) @Nat n ('[] @Nat)) Float)
-> SNat n
-> SingletonTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
forall (t :: TK) (n :: Nat).
SingletonTK t -> SNat n -> SingletonTK (Tups n t)
stkOfListR (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(TKS '[widthHidden] Float)) (forall (n :: Nat). KnownNat n => SNat n
SNat @widthHidden2)) ((KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
Benchmark)
-> Benchmark)
-> (KnownSTK (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)) =>
Benchmark)
-> Benchmark
forall a b. (a -> b) -> a -> b
$
let valsInit :: MnistFcnnRanked1.ADFcnnMnist1Parameters
Concrete widthHidden widthHidden2 r
valsInit :: ADFcnnMnist1Parameters Concrete n n r
valsInit = (ADFcnnMnist1Parameters Concrete n n r, StdGen)
-> ADFcnnMnist1Parameters Concrete n n r
forall a b. (a, b) -> a
fst ((ADFcnnMnist1Parameters Concrete n n r, StdGen)
-> ADFcnnMnist1Parameters Concrete n n r)
-> (ADFcnnMnist1Parameters Concrete n n r, StdGen)
-> ADFcnnMnist1Parameters Concrete n n r
forall a b. (a -> b) -> a -> b
$ Double
-> StdGen
-> (((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR
SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
StdGen)
forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue Double
1 (Int -> StdGen
mkStdGen Int
44)
targetInit :: Concrete (XParams widthHidden widthHidden2 r)
targetInit :: Concrete (XParams n n r)
targetInit = forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete ((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR
SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
ADFcnnMnist1Parameters Concrete n n r
valsInit
in do
let dataInit :: (Concrete (TKR 1 Double), Concrete (TKR 1 Double))
dataInit = case [MnistDataLinearR r]
xs of
MnistDataLinearR r
d : [MnistDataLinearR r]
_ -> (Ranked 1 Double -> Concrete (TKR 1 Double)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete (Ranked 1 Double -> Concrete (TKR 1 Double))
-> (Ranked 1 Double -> Concrete (TKR 1 Double))
-> MnistDataLinearR Double
-> (Concrete (TKR 1 Double), Concrete (TKR 1 Double))
forall b c b' c'. (b -> c) -> (b' -> c') -> (b, b') -> (c, c')
forall (a :: Type -> Type -> Type) b c b' c'.
Arrow a =>
a b c -> a b' c' -> a (b, b') (c, c')
*** Ranked 1 Double -> Concrete (TKR 1 Double)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete) MnistDataLinearR r
d
[] -> String -> (Concrete (TKR 1 Double), Concrete (TKR 1 Double))
forall a. HasCallStack => String -> a
error String
"empty test data"
f :: ( MnistFcnnRanked1.ADFcnnMnist1Parameters
(AstTensor AstMethodLet FullSpan)
widthHidden widthHidden2 Double
, ( AstTensor AstMethodLet FullSpan (TKR 1 Double)
, AstTensor AstMethodLet FullSpan (TKR 1 Double) ) )
-> AstTensor AstMethodLet FullSpan (TKScalar Double)
f :: (ADFcnnMnist1Parameters
(AstTensor AstMethodLet FullSpan) n n Double,
(AstTensor AstMethodLet FullSpan (TKR 1 Double),
AstTensor AstMethodLet FullSpan (TKR 1 Double)))
-> AstTensor AstMethodLet FullSpan (TKScalar Double)
f = \ (ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) n n Double
pars, (AstTensor AstMethodLet FullSpan (TKR 1 Double)
glyphR, AstTensor AstMethodLet FullSpan (TKR 1 Double)
labelR)) ->
forall (target :: Target) r (widthHidden :: Nat)
(widthHidden2 :: Nat).
(ADReady target, GoodScalar r, Differentiable r) =>
SNat widthHidden
-> SNat widthHidden2
-> (target (TKR 1 r), target (TKR 1 r))
-> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
-> target (TKScalar r)
MnistFcnnRanked1.afcnnMnistLoss1 @_ @Double
SNat n
widthHiddenSNat SNat n
widthHidden2SNat
(AstTensor AstMethodLet FullSpan (TKR 1 Double)
glyphR, AstTensor AstMethodLet FullSpan (TKR 1 Double)
labelR) ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) n n Double
pars
artRaw :: AstArtifactRev
(X (((ListR
n
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR
n
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
(AstTensor AstMethodLet FullSpan (TKR 1 Double),
AstTensor AstMethodLet FullSpan (TKR 1 Double))))
(TKScalar Double)
artRaw = ((((ListR
n
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR
n
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
(AstTensor AstMethodLet FullSpan (TKR 1 Double),
AstTensor AstMethodLet FullSpan (TKR 1 Double)))
-> AstTensor AstMethodLet FullSpan (TKScalar Double))
-> Value
(((ListR
n
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR
n
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
(AstTensor AstMethodLet FullSpan (TKR 1 Double),
AstTensor AstMethodLet FullSpan (TKR 1 Double)))
-> AstArtifactRev
(X (((ListR
n
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR
n
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
(AstTensor AstMethodLet FullSpan (TKR 1 Double),
AstTensor AstMethodLet FullSpan (TKR 1 Double))))
(TKScalar Double)
forall src r tgt.
((X src :: TK) ~ (X (Value src) :: TK), KnownSTK (X src),
AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
AdaptableTarget Concrete (Value src),
(tgt :: Type)
~ (AstTensor AstMethodLet FullSpan (TKScalar r) :: Type)) =>
(src -> tgt) -> Value src -> AstArtifactRev (X src) (TKScalar r)
gradArtifact (((ListR
n
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR
n
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
(AstTensor AstMethodLet FullSpan (TKR 1 Double),
AstTensor AstMethodLet FullSpan (TKR 1 Double)))
-> AstTensor AstMethodLet FullSpan (TKScalar Double)
(ADFcnnMnist1Parameters
(AstTensor AstMethodLet FullSpan) n n Double,
(AstTensor AstMethodLet FullSpan (TKR 1 Double),
AstTensor AstMethodLet FullSpan (TKR 1 Double)))
-> AstTensor AstMethodLet FullSpan (TKScalar Double)
f (((ListR n (Concrete (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR n (Concrete (TKS ((':) @Nat n ('[] @Nat)) Float)),
Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR
SizeMnistLabel (Concrete (TKS ((':) @Nat n ('[] @Nat)) Double)),
Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
ADFcnnMnist1Parameters Concrete n n r
valsInit, (Concrete (TKR 1 Double), Concrete (TKR 1 Double))
dataInit)
art :: AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 r) (TKR 1 r)))
(TKScalar Double)
art = AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 r) (TKR 1 r)))
(TKScalar Double)
-> AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 r) (TKR 1 r)))
(TKScalar Double)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifactGradient AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 r) (TKR 1 r)))
(TKScalar Double)
AstArtifactRev
(X (((ListR
n
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat 784 ('[] @Nat)) Double)),
AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR
n
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Float)),
AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet FullSpan (TKS ((':) @Nat n ('[] @Nat)) Double)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))),
(AstTensor AstMethodLet FullSpan (TKR 1 Double),
AstTensor AstMethodLet FullSpan (TKR 1 Double))))
(TKScalar Double)
artRaw
go :: [MnistDataLinearR r]
-> Concrete (XParams widthHidden widthHidden2 r)
-> Concrete (XParams widthHidden widthHidden2 r)
go :: [MnistDataLinearR r]
-> Concrete (XParams n n r) -> Concrete (XParams n n r)
go [] Concrete (XParams n n r)
parameters = Concrete (XParams n n r)
parameters
go ((Ranked 1 r
glyph, Ranked 1 r
label) : [MnistDataLinearR r]
rest) !Concrete (XParams n n r)
parameters =
let parametersAndInput :: Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 r) (TKR 1 r)))
parametersAndInput =
Concrete
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> Concrete (TKProduct (TKR 1 r) (TKR 1 r))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 r) (TKR 1 r)))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair Concrete
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
Concrete (XParams n n r)
parameters (Concrete (TKR 1 r)
-> Concrete (TKR 1 r) -> Concrete (TKProduct (TKR 1 r) (TKR 1 r))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (Ranked 1 r -> Concrete (TKR 1 r)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
glyph) (Ranked 1 r -> Concrete (TKR 1 r)
forall r (target :: Target) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 1 r
label))
gradient :: Concrete
(TKProduct
(TKProduct
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
gradient = Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 Double) (TKR 1 Double)))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
forall (x :: TK) (z :: TK). Concrete (TKProduct x z) -> Concrete x
forall (target :: Target) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 Double) (TKR 1 Double)))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 Double) (TKR 1 Double)))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
forall a b. (a -> b) -> a -> b
$ (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 Double) (TKR 1 Double))),
Concrete (TKScalar Double))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 Double) (TKR 1 Double)))
forall a b. (a, b) -> a
fst
((Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 Double) (TKR 1 Double))),
Concrete (TKScalar Double))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 Double) (TKR 1 Double))))
-> (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 Double) (TKR 1 Double))),
Concrete (TKScalar Double))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 Double) (TKR 1 Double)))
forall a b. (a -> b) -> a -> b
$ AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 r) (TKR 1 r)))
(TKScalar Double)
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 r) (TKR 1 r)))
-> Maybe (Concrete (ADTensorKind (TKScalar Double)))
-> (Concrete
(ADTensorKind
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 r) (TKR 1 r)))),
Concrete (TKScalar Double))
forall (x :: TK) (z :: TK).
AstArtifactRev x z
-> Concrete x
-> Maybe (Concrete (ADTensorKind z))
-> (Concrete (ADTensorKind x), Concrete z)
revInterpretArtifact AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 r) (TKR 1 r)))
(TKScalar Double)
art Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
(TKProduct (TKR 1 r) (TKR 1 r)))
parametersAndInput Maybe (Concrete (ADTensorKind (TKScalar Double)))
Maybe (Concrete (TKScalar Double))
forall a. Maybe a
Nothing
in [MnistDataLinearR r]
-> Concrete (XParams n n r) -> Concrete (XParams n n r)
go [MnistDataLinearR r]
rest (Double
-> SingletonTK
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> Concrete
(ADTensorKind
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
forall (y :: TK).
Double
-> SingletonTK y
-> Concrete y
-> Concrete (ADTensorKind y)
-> Concrete y
updateWithGradient Double
gamma SingletonTK
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
Concrete (XParams n n r)
parameters Concrete
(ADTensorKind
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
Concrete
(TKProduct
(TKProduct
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double)))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(ADTensorKind (Tups n (TKS ((':) @Nat n ('[] @Nat)) Float)))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
gradient)
chunk :: [MnistDataLinearR r]
chunk = Int -> [MnistDataLinearR r] -> [MnistDataLinearR r]
forall a. Int -> [a] -> [a]
take Int
batchSize [MnistDataLinearR r]
xs
gradf :: [MnistDataLinearR r] -> Concrete (XParams n n r)
gradf [MnistDataLinearR r]
c = [MnistDataLinearR r]
-> Concrete (XParams n n r) -> Concrete (XParams n n r)
go [MnistDataLinearR r]
c Concrete (XParams n n r)
targetInit
name :: String
name =
String
prefix
String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords
[ String
"v" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show (SingletonTK (XParams n n r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
(SingletonTK (XParams n n r) -> Int)
-> SingletonTK (XParams n n r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams widthHidden widthHidden2 r))
, String
"m0" String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
" =" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show (SingletonTK
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct (TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
Concrete (XParams n n r)
targetInit) ]
String -> Benchmarkable -> Benchmark
bench String
name (Benchmarkable -> Benchmark) -> Benchmarkable -> Benchmark
forall a b. (a -> b) -> a -> b
$ ([MnistDataLinearR r]
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double))))
-> [MnistDataLinearR r] -> Benchmarkable
forall b a. NFData b => (a -> b) -> a -> Benchmarkable
nf [MnistDataLinearR r]
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(Tups n (TKS ((':) @Nat 784 ('[] @Nat)) Double))
(TKS ((':) @Nat n ('[] @Nat)) Double))
(TKProduct
(Tups n (TKS ((':) @Nat n ('[] @Nat)) Float))
(TKS ((':) @Nat n ('[] @Nat)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double)
(TKProduct
(TKS ((':) @Nat n ('[] @Nat)) Double) TKUnit))))))))))
(TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) Double)))
[MnistDataLinearR r] -> Concrete (XParams n n r)
gradf [MnistDataLinearR r]
chunk
-- | \"Test\" benchmark entry used by the rank 1 VTO benchmark group.
--
-- This is a plain alias: all arguments are forwarded unchanged to
-- 'mnistTestBench1VTA', so the type is exactly that of the callee.
--
-- NOTE(review): the VTO entry point delegates to the VTA variant rather
-- than to a VTO-specific implementation -- confirm this is intentional.
mnistTestBench1VTO
  :: forall r. r ~ Double
  => String -> Int -> Int -> Double -> Int
  -> [MnistDataLinearR r]
  -> Benchmark
mnistTestBench1VTO = mnistTestBench1VTA
-- | Benchmark group for the 2-hidden-layer rank 1 VTO MNIST network.
--
-- The benchmark environment reads the data at 'testGlyphsPath' /
-- 'testLabelsPath', shuffles it with the fixed generator @mkStdGen 42@ and
-- keeps the first @chunkLength@ samples, converted with
-- 'mkMnistDataLinearR'.  Every benchmark in the group receives that same
-- sample list; @chunkLength@ is also passed as the @Int@ argument that
-- follows the @Double@ argument, which is always @0.02@.
mnistBGroup1VTO :: Int -> Benchmark
mnistBGroup1VTO chunkLength = env loadSamples benchmarks
 where
  groupTitle :: String
  groupTitle =
    "2-hidden-layer rank 1 VTO MNIST nn with samples: " ++ show chunkLength

  -- Prepares the shared input once; @$!@ forces the resulting list to
  -- weak head normal form before it is returned.
  loadSamples :: IO [MnistDataLinearR Double]
  loadSamples = do
    allTestData <- loadMnistData testGlyphsPath testLabelsPath
    let shuffled = shuffle (mkStdGen 42) allTestData
    return $! map mkMnistDataLinearR (take chunkLength shuffled)

  -- Label prefix and the two @Int@ size arguments for the configurations
  -- that are benchmarked both as \"test\" and as \"train\".
  pairedConfigs :: [(String, Int, Int)]
  pairedConfigs =
    [ ("30|10 ", 30, 10)
    , ("300|100 ", 300, 100)
    , ("500|150 ", 500, 150)
    ]

  -- Does not pattern-match on @samples@, so the group structure can be
  -- built without inspecting the environment.
  benchmarks :: [MnistDataLinearR Double] -> Benchmark
  benchmarks samples =
    bgroup groupTitle $
      concat
        [ [ mnistTestBench1VTO label width1 width2 0.02 chunkLength samples
          , mnistTrainBench1VTO label width1 width2 0.02 chunkLength samples
          ]
        | (label, width1, width2) <- pairedConfigs
        ]
        -- The largest configuration is only run as a training benchmark.
        ++ [mnistTrainBench1VTO "1500|500 " 1500 500 0.02 chunkLength samples]
-- | Benchmark one 'sgd' run of the rank-2 (matrix-based) MNIST network,
-- differentiated directly through the @ADVal Concrete@ instance.
--
-- Arguments, in order:
--
-- * @prefix@: text prepended to the generated benchmark name;
-- * @widthHidden@, @widthHidden2@: sizes of the two hidden layers, turned
--   into type-level naturals with 'withSNat';
-- * @gamma@: passed unchanged as the first argument of 'sgd'
--   (presumably the learning rate -- confirm against 'sgd');
-- * @batchSize@: how many samples are taken from the front of @xs@;
-- * @xs@: the data the benchmarked chunk is taken from.
--
-- The benchmark name is @prefix@ followed by @"v0 m\<width\> =\<size\>"@,
-- where the two numbers come from 'widthSTK' of the parameter type and
-- 'tsize' of the initial parameters.
mnistTrainBench2VTA
  :: forall r. r ~ Double
  => String
  -> Int -> Int -> Double -> Int -> [MnistDataLinearR r]
  -> Benchmark
mnistTrainBench2VTA prefix widthHidden widthHidden2
                    gamma batchSize xs =
  withSNat widthHidden $ \(SNat @widthHidden) ->
  withSNat widthHidden2 $ \(SNat @widthHidden2) ->
  let -- Initial parameters are generated at the shaped type (so the hidden
      -- widths determine the tensor sizes) from a fixed seed, then
      -- converted with 'forgetShape' to the ranked representation that
      -- 'sgd' and the loss function below operate on.
      targetInit =
        forgetShape $ fst
        $ randomValue
            @(Concrete (X (MnistFcnnRanked2.ADFcnnMnist2ParametersShaped
                             Concrete widthHidden widthHidden2 r Float)))
            1 (mkStdGen 44)
  in do
    let -- Loss of a single (glyph, label) sample at the given parameters.
        f :: MnistDataLinearR Double -> ADVal Concrete (XParams2 Double Float)
          -> ADVal Concrete (TKScalar Double)
        f (glyph, label) adinputs =
          MnistFcnnRanked2.afcnnMnistLoss2
            (rconcrete glyph, rconcrete label) (fromTarget adinputs)
        chunk = take batchSize xs
        -- Only the updated parameters (first component) are kept.
        gradf c = fst $ sgd gamma f c targetInit
        name =
          prefix
          ++ unwords
               [ "v0 m" ++ show (widthSTK $ knownSTK @(XParams2 r Float))
               , "=" ++ show (tsize knownSTK targetInit) ]
    -- 'nf' forces the resulting parameters to normal form.
    bench name $ nf gradf chunk
-- | Benchmark evaluating the rank-2 (matrix-based) MNIST network with
-- 'MnistFcnnRanked2.afcnnMnistTest2' on freshly initialised parameters.
--
-- The argument list is the same as that of 'mnistTrainBench2VTA'; the
-- fourth argument (@_gamma@) is accepted but not used here.
--
-- * @prefix@: text placed after the literal @"test "@ in the benchmark name;
-- * @widthHidden@, @widthHidden2@: sizes of the two hidden layers, turned
--   into type-level naturals with 'withSNat';
-- * @batchSize@: how many samples are taken from the front of @xs@;
-- * @xs@: the data the benchmarked chunk is taken from.
--
-- The benchmark name is @"test "@, then @prefix@, then
-- @"v0 m\<width\> =\<size\>"@, where the two numbers come from 'widthSTK'
-- of the parameter type and 'tsize' of the initial parameters.
mnistTestBench2VTA
  :: forall r. r ~ Double
  => String
  -> Int -> Int -> Double -> Int -> [MnistDataLinearR r]
  -> Benchmark
mnistTestBench2VTA prefix widthHidden widthHidden2
                   _gamma batchSize xs =
  withSNat widthHidden $ \(SNat @widthHidden) ->
  withSNat widthHidden2 $ \(SNat @widthHidden2) ->
  let -- Parameters are generated at the shaped type from a fixed seed and
      -- then converted with 'forgetShape' to the ranked representation
      -- consumed by the test function.
      targetInit =
        forgetShape $ fst
        $ randomValue
            @(Concrete (X (MnistFcnnRanked2.ADFcnnMnist2ParametersShaped
                             Concrete widthHidden widthHidden2 r Float)))
            1 (mkStdGen 44)
  in do
    let chunk = take batchSize xs
        score c = MnistFcnnRanked2.afcnnMnistTest2 c (fromTarget targetInit)
        name =
          "test " ++ prefix
          ++ unwords
               [ "v0 m" ++ show (widthSTK $ knownSTK @(XParams2 r Float))
               , "=" ++ show (tsize knownSTK targetInit) ]
    -- The score is a plain scalar (@r ~ Double@), so 'whnf' evaluates it.
    bench name $ whnf score chunk
-- | Benchmark group for the rank-2 VTA MNIST network.
--
-- The environment action loads the MNIST data from 'testGlyphsPath' and
-- 'testLabelsPath', shuffles it with a fixed seed, keeps the first
-- @chunkLength@ samples and converts them with 'mkMnistDataLinearR'.
-- The same samples are then shared by all benchmarks in the group, and
-- @chunkLength@ is also passed to each of them as the batch size.
--
-- Test and training benchmarks are run for the hidden-layer sizes
-- 30|10, 300|100 and 500|150; for 1500|500 only the training benchmark
-- is run.
mnistBGroup2VTA :: Int -> Benchmark
mnistBGroup2VTA chunkLength =
  env (do
    testData0 <- loadMnistData testGlyphsPath testLabelsPath
    let testData = shuffle (mkStdGen 42) testData0
    return $! map mkMnistDataLinearR $ take chunkLength testData) $
  \ xs ->
  bgroup ("2-hidden-layer rank 2 VTA MNIST nn with samples: "
          ++ show chunkLength)
    [ mnistTestBench2VTA "30|10 " 30 10 0.02 chunkLength xs
    , mnistTrainBench2VTA "30|10 " 30 10 0.02 chunkLength xs
    , mnistTestBench2VTA "300|100 " 300 100 0.02 chunkLength xs
    , mnistTrainBench2VTA "300|100 " 300 100 0.02 chunkLength xs
    , mnistTestBench2VTA "500|150 " 500 150 0.02 chunkLength xs
    , mnistTrainBench2VTA "500|150 " 500 150 0.02 chunkLength xs
    , mnistTrainBench2VTA "1500|500 " 1500 500 0.02 chunkLength xs
    ]
-- | Benchmark building and simplifying the gradient artifact of the
-- rank-2 MNIST network; no training data is involved.
--
-- * @prefix@: used verbatim as the benchmark name;
-- * @widthHidden@: fixed inside the benchmarked function;
-- * @widthHidden2@: the argument the benchmarked function is applied to
--   on every run.
--
-- The benchmarked function calls
-- 'MnistFcnnRanked2.mnistTrainBench2VTOGradient' (with @Double@ and
-- @Float@ as the two scalar types, 'IgnoreIncomingCotangent', range @1@
-- and a fixed seed), drops the initial parameters with 'snd' and passes
-- the artifact through 'simplifyArtifactGradient'.
--
-- NOTE(review): the result is forced with 'whnf' only; confirm that weak
-- head normal form of the artifact is enough to make the simplification
-- work happen inside the measured region.
mnistTrainBench2VTC
  :: String
  -> Int -> Int
  -> Benchmark
mnistTrainBench2VTC prefix widthHidden widthHidden2 =
  bench prefix
  $ whnf (simplifyArtifactGradient . snd
          . MnistFcnnRanked2.mnistTrainBench2VTOGradient
              @Double (Proxy @Float) IgnoreIncomingCotangent
              1 (mkStdGen 44) widthHidden)
         widthHidden2
-- | Benchmark group for gradient-artifact compilation of the rank-2 MNIST
-- network, run for the hidden-layer sizes 30|10, 300|100, 500|150 and
-- 1500|500.
--
-- @chunkLength@ only appears in the group label; it is not passed to the
-- individual benchmarks.
mnistBGroup2VTC :: Int -> Benchmark
mnistBGroup2VTC chunkLength =
  bgroup ("2-hidden-layer rank 2 VTC compilation MNIST nn with samples: "
          ++ show chunkLength)
    [ mnistTrainBench2VTC "30|10 " 30 10
    , mnistTrainBench2VTC "300|100 " 300 100
    , mnistTrainBench2VTC "500|150 " 500 150
    , mnistTrainBench2VTC "1500|500 " 1500 500
    ]
-- | Benchmark a pass of per-sample gradient updates driven by an already
-- built reverse-derivative artifact.
--
-- Starting from @targetInit@, the first @batchSize@ elements of @xs@ are
-- processed one by one: the current parameters are paired with the sample's
-- glyph and label, the artifact @art@ is interpreted on that pair with no
-- incoming cotangent ('Nothing'), the parameter component of the result is
-- taken as the gradient, and the parameters are updated with
-- 'updateWithGradient' using @gamma@.
--
-- The benchmark is named @prefix@ followed by @"v0 m"@ plus the 'widthSTK'
-- of the parameter kind and @"="@ plus the 'tsize' of @targetInit@.
-- The final parameters are forced with 'nf'.
mnistTrainBench2VTOO
  :: forall r. r ~ Double
  => String
  -> Double -> Int -> [MnistDataLinearR r]
  -> ( Concrete (XParams2 r Float)
     , AstArtifactRev
         (TKProduct
            (XParams2 r Float)
            (TKProduct (TKR2 1 (TKScalar Double))
                       (TKR2 1 (TKScalar Double))))
         (TKScalar r) )
  -> Benchmark
mnistTrainBench2VTOO prefix gamma batchSize xs (targetInit, art) =
  let -- Fold over the samples, threading the parameters; the bang keeps
      -- the accumulated parameters from piling up as thunks.
      go :: [MnistDataLinearR r] -> Concrete (XParams2 r Float)
         -> Concrete (XParams2 r Float)
      go [] parameters = parameters
      go ((glyph, label) : rest) !parameters =
        let parametersAndInput =
              tpair parameters (tpair (rconcrete glyph) (rconcrete label))
            -- The artifact yields a gradient for the whole input pair;
            -- only the parameter component is used.
            gradient =
              tproject1 $ fst
              $ revInterpretArtifact art parametersAndInput Nothing
        in go rest (updateWithGradient gamma knownSTK parameters gradient)
      chunk = take batchSize xs
      gradf c = go c targetInit
      name =
        prefix
        ++ unwords
             [ "v0 m" ++ show (widthSTK $ knownSTK @(XParams2 r Float))
             , "=" ++ show (tsize knownSTK targetInit) ]
  in bench name $ nf gradf chunk
-- | Build the initial parameters and the gradient artifact for a network
-- with the given hidden-layer widths, simplify the artifact, and hand both
-- to 'mnistTrainBench2VTOO'.
--
-- The artifact is created by 'MnistFcnnRanked2.mnistTrainBench2VTOGradient'
-- with 'IgnoreIncomingCotangent', a range argument of @1@ and the fixed
-- generator @mkStdGen 44@, then passed through 'simplifyArtifactGradient'.
-- @prefix@, @gamma@, @batchSize@ and @xs@ are forwarded unchanged.
mnistTrainBench2VTO
  :: forall r. r ~ Double
  => String
  -> Int -> Int -> Double -> Int -> [MnistDataLinearR r]
  -> Benchmark
mnistTrainBench2VTO prefix widthHidden widthHidden2
                    gamma batchSize xs =
  let (!targetInit, !artRaw) =
        MnistFcnnRanked2.mnistTrainBench2VTOGradient
          @Double (Proxy @Float) IgnoreIncomingCotangent
          1 (mkStdGen 44) widthHidden widthHidden2
      -- Strict binding: the simplified artifact is evaluated (to WHNF)
      -- as soon as the body of this @let@ is demanded.
      !art = simplifyArtifactGradient artRaw
  in mnistTrainBench2VTOO prefix gamma batchSize xs (targetInit, art)
-- | Benchmark group for the 1500|500 network trained through a simplified
-- gradient artifact.
--
-- The initial parameters and the artifact are built once, outside 'env',
-- from the fixed generator @mkStdGen 44@; the artifact is then simplified
-- with 'simplifyArtifactGradient'.  The 'env' action loads the MNIST test
-- data, shuffles it with the fixed generator @mkStdGen 42@ and keeps the
-- first @chunkLength@ samples, converted with 'mkMnistDataLinearR'.
-- The single benchmark runs 'mnistTrainBench2VTOO' with @gamma = 0.02@ and
-- @chunkLength@ as the batch size.
mnistBGroup2VTO :: Int -> Benchmark
mnistBGroup2VTO chunkLength =
  let (!targetInit, !artRaw) =
        MnistFcnnRanked2.mnistTrainBench2VTOGradient
          @Double (Proxy @Float) IgnoreIncomingCotangent
          1 (mkStdGen 44) 1500 500
      !art = simplifyArtifactGradient artRaw
      -- Environment for 'env': the samples the benchmark iterates over.
      loadSamples :: IO [MnistDataLinearR Double]
      loadSamples = do
        testData0 <- loadMnistData testGlyphsPath testLabelsPath
        let testData = shuffle (mkStdGen 42) testData0
        return $! map mkMnistDataLinearR $ take chunkLength testData
  in env loadSamples $ \xs ->
       bgroup ("2-hidden-layer rank 2 VTO runtime MNIST nn with samples: "
               ++ show chunkLength)
         [ mnistTrainBench2VTOO "1500|500 " 0.02 chunkLength xs
                                (targetInit, art)
         ]
-- | Benchmark a pass of per-sample gradient updates driven by an already
-- built reverse-derivative artifact.
--
-- The training loop is the same as the one in 'mnistTrainBench2VTOO':
-- starting from @targetInit@, each of the first @batchSize@ elements of
-- @xs@ is paired with the current parameters, the artifact @art@ is
-- interpreted on that pair with no incoming cotangent ('Nothing'), the
-- parameter component of the result is taken as the gradient, and the
-- parameters are updated with 'updateWithGradient' using @gamma@.
--
-- The benchmark is named @prefix@ followed by @"v0 m"@ plus the 'widthSTK'
-- of the parameter kind and @"="@ plus the 'tsize' of @targetInit@.
-- The final parameters are forced with 'nf'.
mnistTrainBench2VTOZ
  :: forall r. r ~ Double
  => String
  -> Double -> Int -> [MnistDataLinearR r]
  -> ( Concrete (XParams2 r Float)
     , AstArtifactRev
         (TKProduct
            (XParams2 r Float)
            (TKProduct (TKR2 1 (TKScalar Double))
                       (TKR2 1 (TKScalar Double))))
         (TKScalar r) )
  -> Benchmark
mnistTrainBench2VTOZ prefix gamma batchSize xs (targetInit, art) =
  let -- Fold over the samples, threading the parameters; the bang keeps
      -- the accumulated parameters from piling up as thunks.
      go :: [MnistDataLinearR r] -> Concrete (XParams2 r Float)
         -> Concrete (XParams2 r Float)
      go [] parameters = parameters
      go ((glyph, label) : rest) !parameters =
        let parametersAndInput =
              tpair parameters (tpair (rconcrete glyph) (rconcrete label))
            -- The artifact yields a gradient for the whole input pair;
            -- only the parameter component is used.
            gradient =
              tproject1 $ fst
              $ revInterpretArtifact art parametersAndInput Nothing
        in go rest (updateWithGradient gamma knownSTK parameters gradient)
      chunk = take batchSize xs
      gradf c = go c targetInit
      name =
        prefix
        ++ unwords
             [ "v0 m" ++ show (widthSTK $ knownSTK @(XParams2 r Float))
             , "=" ++ show (tsize knownSTK targetInit) ]
  in bench name $ nf gradf chunk
-- | Criterion benchmark group labelled
-- \"2-hidden-layer rank 2 VTOZ runtime MNIST nn\".
--
-- The initial parameters and the reverse-derivative artifact are obtained
-- once, outside of the measured code, from
-- 'MnistFcnnRanked2.mnistTrainBench2VTOGradient' (parameter scalar @Double@,
-- second scalar type @Float@, 'IgnoreIncomingCotangent', generator seed 44).
-- The MNIST test set is loaded inside 'env', shuffled with a fixed seed (42)
-- and truncated to @chunkLength@ samples, so only the training pass itself
-- is timed.
--
-- NOTE(review): the trailing @1500@ and @500@ are presumably the two hidden
-- layer widths (they match the @\"1500|500 \"@ label) and the @1@ looks like
-- an initialisation range -- confirm against "MnistFcnnRanked2".
mnistBGroup2VTOZ :: Int -> Benchmark
mnistBGroup2VTOZ chunkLength =
  let -- The bangs force both components as soon as the pair is demanded,
      -- keeping artifact construction out of the measured benchmark.
      (!targetInit, !art) =
        MnistFcnnRanked2.mnistTrainBench2VTOGradient
          @Double (Proxy @Float) IgnoreIncomingCotangent
          1 (mkStdGen 44) 1500 500
      -- Load, deterministically shuffle and truncate the test data.
      loadSamples :: IO [MnistDataLinearR Double]
      loadSamples = do
        testData0 <- loadMnistData testGlyphsPath testLabelsPath
        let testData = shuffle (mkStdGen 42) testData0
        return $! map mkMnistDataLinearR $ take chunkLength testData
  in env loadSamples $ \xs ->
       bgroup ("2-hidden-layer rank 2 VTOZ runtime MNIST nn with samples: "
               ++ show chunkLength)
         [ mnistTrainBench2VTOO "1500|500 " 0.02 chunkLength xs
                                (targetInit, art)
         ]
-- | Benchmark a single gradient-descent pass over a batch of MNIST samples,
-- interpreting a precomputed reverse-derivative artifact for every sample.
--
-- Arguments, in order:
--
-- * @prefix@: text prepended to the generated benchmark name;
-- * @gamma@: the 'Double' handed to 'updateWithGradient' on every step
--   (presumably the learning rate -- confirm against its definition);
-- * @batchSize@: how many samples are taken from the front of @xs@;
-- * @xs@: the available samples, each a (glyph, label) pair;
-- * @(targetInit, art)@: the initial parameters and the gradient artifact
--   that is interpreted at every step.
--
-- The benchmark name is @prefix@ followed by @\"v0 m\<width\> =\<size\>\"@,
-- where the width comes from 'widthSTK' of the parameter type and the size
-- from 'tsize' of the initial parameters.
mnistTrainBench2VTOX
  :: forall r. r ~ Double
  => String
  -> Double -> Int -> [MnistDataLinearR r]
  -> ( Concrete (XParams2 r Float)
     , AstArtifactRev
         (TKProduct
            (XParams2 r Float)
            (TKProduct (TKR2 1 (TKScalar Double))
                       (TKR2 1 (TKScalar Double))))
         (TKScalar r) )
  -> Benchmark
mnistTrainBench2VTOX prefix gamma batchSize xs (targetInit, art) =
  -- 'nf' makes criterion evaluate the final parameters to normal form,
  -- so the whole pass is actually executed on every run.
  bench name $ nf gradf chunk
 where
  -- Consume the samples one by one, threading the updated parameters.
  -- The bang forces the accumulator at every step, so no chain of
  -- unevaluated updates builds up over the batch.
  go :: [MnistDataLinearR r] -> Concrete (XParams2 r Float)
     -> Concrete (XParams2 r Float)
  go [] parameters = parameters
  go ((glyph, label) : rest) !parameters =
    let -- The artifact takes the parameters paired with (glyph, label).
        parametersAndInput =
          tpair parameters (tpair (rconcrete glyph) (rconcrete label))
        -- Interpret the artifact with no incoming cotangent ('Nothing'),
        -- keep the gradient component of the result and project out the
        -- part that corresponds to the parameters.
        gradient =
          tproject1 $ fst
          $ revInterpretArtifact art parametersAndInput Nothing
    in go rest (updateWithGradient gamma knownSTK parameters gradient)

  -- The samples that are actually processed by the benchmark.
  chunk :: [MnistDataLinearR r]
  chunk = take batchSize xs

  -- One full pass over the given samples, starting from 'targetInit'.
  gradf :: [MnistDataLinearR r] -> Concrete (XParams2 r Float)
  gradf c = go c targetInit

  name :: String
  name =
    prefix
    ++ unwords
         [ "v0 m" ++ show (widthSTK $ knownSTK @(XParams2 r Float))
         , "=" ++ show (tsize knownSTK targetInit) ]
-- | Criterion benchmark group labelled
-- \"2-hidden-layer rank 2 VTOX runtime MNIST nn\".
--
-- The initial parameters and the reverse-derivative artifact are obtained
-- once, outside of the measured code, from
-- 'MnistFcnnRanked2.mnistTrainBench2VTOGradientX' (parameter scalar
-- @Double@, second scalar type @Float@, 'IgnoreIncomingCotangent',
-- generator seed 44). The MNIST test set is loaded inside 'env', shuffled
-- with a fixed seed (42) and truncated to @chunkLength@ samples, so only
-- the training pass itself is timed.
--
-- NOTE(review): the trailing @1500@ and @500@ are presumably the two hidden
-- layer widths (they match the @\"1500|500 \"@ label) and the @1@ looks like
-- an initialisation range -- confirm against "MnistFcnnRanked2".
--
-- NOTE(review): the group runs 'mnistTrainBench2VTOO', not the
-- 'mnistTrainBench2VTOX' defined in this module with the same signature.
-- The call is kept as found; confirm that this is intended.
mnistBGroup2VTOX :: Int -> Benchmark
mnistBGroup2VTOX chunkLength =
  let -- The bangs force both components as soon as the pair is demanded,
      -- keeping artifact construction out of the measured benchmark.
      (!targetInit, !art) =
        MnistFcnnRanked2.mnistTrainBench2VTOGradientX
          @Double (Proxy @Float) IgnoreIncomingCotangent
          1 (mkStdGen 44) 1500 500
      -- Load, deterministically shuffle and truncate the test data.
      loadSamples :: IO [MnistDataLinearR Double]
      loadSamples = do
        testData0 <- loadMnistData testGlyphsPath testLabelsPath
        let testData = shuffle (mkStdGen 42) testData0
        return $! map mkMnistDataLinearR $ take chunkLength testData
  in env loadSamples $ \xs ->
       bgroup ("2-hidden-layer rank 2 VTOX runtime MNIST nn with samples: "
               ++ show chunkLength)
         [ mnistTrainBench2VTOO "1500|500 " 0.02 chunkLength xs
                                (targetInit, art)
         ]