{-# OPTIONS_GHC -Wno-missing-export-lists #-}
module MnistFcnnRanked2 where
import Prelude
import Data.Proxy (Proxy (Proxy))
import Data.Vector.Generic qualified as V
import GHC.Exts (inline)
import GHC.TypeLits (Nat)
import System.Random
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Ranked.Shape
import HordeAd
import HordeAd.Core.Adaptor
import HordeAd.Core.CarriersAst
import MnistData
-- | The type of all trainable parameters of this neural network, in the
-- shaped variant where every dimension is tracked at the type level.
--
-- The three pairs are, in order, the (weights, bias) of the first hidden
-- layer, of the second hidden layer and of the readout layer. Note that the
-- weights of the second hidden layer use the scalar type @q@, while all
-- other components use @r@.
type ADFcnnMnist2ParametersShaped
       (target :: Target) (widthHidden :: Nat) (widthHidden2 :: Nat) r q =
  ( ( target (TKS '[widthHidden, SizeMnistGlyph] r)
    , target (TKS '[widthHidden] r) )
  , ( target (TKS '[widthHidden2, widthHidden] q)
    , target (TKS '[widthHidden2] r) )
  , ( target (TKS '[SizeMnistLabel, widthHidden2] r)
    , target (TKS '[SizeMnistLabel] r) )
  )
-- | The type of all trainable parameters of this neural network, in the
-- ranked variant where only tensor ranks are tracked at the type level.
--
-- The three pairs are, in order, the (weights, bias) of the first hidden
-- layer, of the second hidden layer and of the readout layer. The weights
-- of the second hidden layer use the scalar type @q@; everything else
-- uses @r@.
type ADFcnnMnist2Parameters (target :: Target) r q =
  ( ( target (TKR 2 r)
    , target (TKR 1 r) )
  , ( target (TKR 2 q)
    , target (TKR 1 r) )
  , ( target (TKR 2 r)
    , target (TKR 1 r) )
  )
-- | Shorthand for 'X' applied to the ranked parameter tuple
-- 'ADFcnnMnist2Parameters' instantiated at the 'Concrete' target,
-- with scalar types @r@ and @q@.
type XParams2 r q = X (MnistFcnnRanked2.ADFcnnMnist2Parameters Concrete r q)
-- | Fully connected neural network with two hidden layers.
--
-- Arguments, in order:
--
-- * the activation function applied after each of the two hidden layers,
-- * the activation function applied after the readout (output) layer,
-- * the input vector,
-- * the parameters: a (weights, bias) pair for each of the first hidden
--   layer, the second hidden layer and the readout layer.
--
-- Each layer computes a matrix-vector product plus a bias. The weights of
-- the second hidden layer have scalar type @q@, so the incoming activation
-- is cast to @q@ before the product and the result is cast back to @r@
-- before the bias is added.
afcnnMnist2 :: ( ADReady target, GoodScalar r, Differentiable r
               , GoodScalar q, Differentiable q )
            => (target (TKR 1 r) -> target (TKR 1 r))
            -> (target (TKR 1 r) -> target (TKR 1 r))
            -> target (TKR 1 r)
            -> ADFcnnMnist2Parameters target r q
            -> target (TKR 1 r)
afcnnMnist2 factivationHidden factivationOutput
            datum ((hidden, bias), (hidden2, bias2), (readout, biasr)) =
  let hiddenLayer1 = rmatvecmul hidden datum + bias
      nonlinearLayer1 = factivationHidden hiddenLayer1
      -- Cast to @q@ for the product with @hidden2@, then back to @r@.
      hiddenLayer2 = rcast (rmatvecmul hidden2 (rcast nonlinearLayer1)) + bias2
      nonlinearLayer2 = factivationHidden hiddenLayer2
      outputLayer = rmatvecmul readout nonlinearLayer2 + biasr
  in factivationOutput outputLayer
-- | The network 'afcnnMnist2' instantiated with 'logistic' as the hidden
-- activation and 'softMax1' as the output activation, composed with the
-- 'lossCrossEntropyV' loss.
--
-- The first argument is the pair (input vector, target vector); the second
-- is the network parameters. The result is the scalar loss of the network
-- output against the target vector.
afcnnMnistLoss2
  :: ( ADReady target, GoodScalar r, Differentiable r
     , GoodScalar q, Differentiable q )
  => (target (TKR 1 r), target (TKR 1 r)) -> ADFcnnMnist2Parameters target r q
  -> target (TKScalar r)
afcnnMnistLoss2 (datum, target) adparams =
  let result = inline afcnnMnist2 logistic softMax1 datum adparams
  in lossCrossEntropyV target result
{-# SPECIALIZE afcnnMnistLoss2 :: (ADVal Concrete (TKR 1 Double), ADVal Concrete (TKR 1 Double)) -> ADFcnnMnist2Parameters (ADVal Concrete) Double Float -> ADVal Concrete (TKScalar Double) #-}
{-# SPECIALIZE afcnnMnistLoss2 :: (ADVal Concrete (TKR 1 Float), ADVal Concrete (TKR 1 Float)) -> ADFcnnMnist2Parameters (ADVal Concrete) Float Float -> ADVal Concrete (TKScalar Float) #-}
{-# SPECIALIZE afcnnMnistLoss2 :: (ADVal Concrete (TKR 1 Double), ADVal Concrete (TKR 1 Double)) -> ADFcnnMnist2Parameters (ADVal Concrete) Double Double -> ADVal Concrete (TKScalar Double) #-}
-- | Evaluate the network on a list of (glyph, label) pairs and return the
-- fraction of pairs for which the index of the largest network output
-- equals the index of the largest label component.
--
-- Returns @0@ for an empty list, which avoids a division by zero.
afcnnMnistTest2
  :: forall target r q.
     ( target ~ Concrete, GoodScalar r, Differentiable r
     , GoodScalar q, Differentiable q )
  => [MnistDataLinearR r]
  -> ADFcnnMnist2Parameters target r q
  -> r
afcnnMnistTest2 [] _ = 0
afcnnMnistTest2 dataList testParams =
  let matchesLabels :: MnistDataLinearR r -> Bool
      matchesLabels (glyph, label) =
        let glyph1 = rconcrete glyph
            nn :: ADFcnnMnist2Parameters target r q
               -> target (TKR 1 r)
            nn = inline afcnnMnist2 logistic softMax1 glyph1
            v = rtoVector $ nn testParams
        in V.maxIndex v == V.maxIndex (Nested.rtoVector label)
  in fromIntegral (length (filter matchesLabels dataList))
     / fromIntegral (length dataList)
-- | Build randomly initialised parameters together with a reverse-mode
-- artifact of the loss function 'afcnnMnistLoss2'.
--
-- Arguments, in order:
--
-- * a proxy fixing the scalar type @q@ of the second hidden layer's weights,
-- * how the incoming cotangent is to be handled by the artifact,
-- * the @range@ passed to 'randomValue',
-- * the random generator passed to 'randomValue',
-- * the width of the first hidden layer,
-- * the width of the second hidden layer.
--
-- The two widths are reified as type-level naturals with 'withSNat', the
-- parameters are generated at the shaped type
-- 'ADFcnnMnist2ParametersShaped' and then converted with 'forgetShape'.
-- The artifact's input is the product of the parameters and a
-- (glyph, label) pair of rank-1 tensors whose lengths are
-- 'sizeMnistGlyphInt' and 'sizeMnistLabelInt'.
mnistTrainBench2VTOGradient
  :: forall r q. ( GoodScalar r, Differentiable r
                 , GoodScalar q, Differentiable q )
  => Proxy q -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int
  -> ( Concrete (XParams2 r q)
     , AstArtifactRev
         (TKProduct
            (XParams2 r q)
            (TKProduct (TKR2 1 (TKScalar r))
                       (TKR2 1 (TKScalar r))))
         (TKScalar r) )
mnistTrainBench2VTOGradient Proxy cotangentHandling range seed
                            widthHidden widthHidden2 =
  withSNat widthHidden $ \(SNat @widthHidden) ->
  withSNat widthHidden2 $ \(SNat @widthHidden2) ->
  let targetInit =
        forgetShape $ fst
        $ randomValue @(Concrete (X (MnistFcnnRanked2.ADFcnnMnist2ParametersShaped
                                       Concrete widthHidden widthHidden2 r q)))
                      range seed
      -- Full shape of the parameters, read off the generated value.
      ftk = tftk @Concrete (knownSTK @(XParams2 r q)) targetInit
      -- Full shape of a single (glyph, label) data point.
      ftkData = FTKProduct (FTKR (sizeMnistGlyphInt :$: ZSR) FTKScalar)
                           (FTKR (sizeMnistLabelInt :$: ZSR) FTKScalar)
      -- The loss with its arguments regrouped as (parameters, data).
      f :: ( MnistFcnnRanked2.ADFcnnMnist2Parameters
               (AstTensor AstMethodLet FullSpan) r q
           , ( AstTensor AstMethodLet FullSpan (TKR 1 r)
             , AstTensor AstMethodLet FullSpan (TKR 1 r) ) )
        -> AstTensor AstMethodLet FullSpan (TKScalar r)
      f (pars, (glyphR, labelR)) =
        afcnnMnistLoss2 (glyphR, labelR) pars
      artRaw = revArtifactAdapt cotangentHandling f (FTKProduct ftk ftkData)
  in (targetInit, artRaw)
{-# SPECIALIZE mnistTrainBench2VTOGradient :: Proxy Float -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int -> ( Concrete (XParams2 Double Float), AstArtifactRev (TKProduct (XParams2 Double Float) (TKProduct (TKR2 1 (TKScalar Double)) (TKR2 1 (TKScalar Double)))) (TKScalar Double) ) #-}
{-# SPECIALIZE mnistTrainBench2VTOGradient :: Proxy Float -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int -> ( Concrete (XParams2 Float Float), AstArtifactRev (TKProduct (XParams2 Float Float) (TKProduct (TKR2 1 (TKScalar Float)) (TKR2 1 (TKScalar Float)))) (TKScalar Float) ) #-}
{-# SPECIALIZE mnistTrainBench2VTOGradient :: Proxy Double -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int -> ( Concrete (XParams2 Double Double), AstArtifactRev (TKProduct (XParams2 Double Double) (TKProduct (TKR2 1 (TKScalar Double)) (TKR2 1 (TKScalar Double)))) (TKScalar Double) ) #-}
mnistTrainBench2VTOGradientX
:: forall r q. ( GoodScalar r, Differentiable r
, GoodScalar q, Differentiable q )
=> Proxy q -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int
-> ( Concrete (XParams2 r q)
, AstArtifactRev
(TKProduct
(XParams2 r q)
(TKProduct (TKR2 1 (TKScalar r))
(TKR2 1 (TKScalar r))))
(TKScalar r) )
mnistTrainBench2VTOGradientX :: forall r q.
(GoodScalar r, Differentiable r, GoodScalar q, Differentiable q) =>
Proxy @Type q
-> IncomingCotangentHandling
-> Double
-> StdGen
-> Int
-> Int
-> (Concrete (XParams2 r q),
AstArtifactRev
(TKProduct
(XParams2 r q)
(TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKScalar r))
mnistTrainBench2VTOGradientX Proxy @Type q
Proxy IncomingCotangentHandling
cotangentHandling Double
range StdGen
seed Int
widthHidden Int
widthHidden2 =
Int
-> (forall (n :: Nat).
KnownNat n =>
SNat n
-> (Concrete (XParams2 r q),
AstArtifactRev
(TKProduct
(XParams2 r q)
(TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKScalar r)))
-> (Concrete (XParams2 r q),
AstArtifactRev
(TKProduct
(XParams2 r q)
(TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKScalar r))
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHidden ((forall (n :: Nat).
KnownNat n =>
SNat n
-> (Concrete (XParams2 r q),
AstArtifactRev
(TKProduct
(XParams2 r q)
(TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKScalar r)))
-> (Concrete (XParams2 r q),
AstArtifactRev
(TKProduct
(XParams2 r q)
(TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKScalar r)))
-> (forall (n :: Nat).
KnownNat n =>
SNat n
-> (Concrete (XParams2 r q),
AstArtifactRev
(TKProduct
(XParams2 r q)
(TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKScalar r)))
-> (Concrete (XParams2 r q),
AstArtifactRev
(TKProduct
(XParams2 r q)
(TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKScalar r))
forall a b. (a -> b) -> a -> b
$ \(SNat @widthHidden) ->
Int
-> (forall (n :: Nat).
KnownNat n =>
SNat n
-> (Concrete (XParams2 r q),
AstArtifactRev
(TKProduct
(XParams2 r q)
(TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKScalar r)))
-> (Concrete (XParams2 r q),
AstArtifactRev
(TKProduct
(XParams2 r q)
(TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKScalar r))
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
widthHidden2 ((forall (n :: Nat).
KnownNat n =>
SNat n
-> (Concrete (XParams2 r q),
AstArtifactRev
(TKProduct
(XParams2 r q)
(TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKScalar r)))
-> (Concrete (XParams2 r q),
AstArtifactRev
(TKProduct
(XParams2 r q)
(TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKScalar r)))
-> (forall (n :: Nat).
KnownNat n =>
SNat n
-> (Concrete (XParams2 r q),
AstArtifactRev
(TKProduct
(XParams2 r q)
(TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKScalar r)))
-> (Concrete (XParams2 r q),
AstArtifactRev
(TKProduct
(XParams2 r q)
(TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKScalar r))
forall a b. (a -> b) -> a -> b
$ \(SNat @widthHidden2) ->
let targetInit :: NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit =
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a b. (a -> b) -> a -> b
$ (Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r q)),
StdGen)
-> Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r q))
forall a b. (a, b) -> a
fst
((Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r q)),
StdGen)
-> Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r q)))
-> (Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r q)),
StdGen)
-> Concrete (X (ADFcnnMnist2ParametersShaped Concrete n n r q))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue @(Concrete (X (MnistFcnnRanked2.ADFcnnMnist2ParametersShaped
Concrete widthHidden widthHidden2 r q)))
Double
range StdGen
seed
ftk :: FullShapeTK
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar q)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
ftk = forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams2 r q)) Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar q)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit
ftkData :: FullShapeTK
(TKProduct (TKR2 (0 + 1) (TKScalar r)) (TKR2 (0 + 1) (TKScalar r)))
ftkData = FullShapeTK (TKR2 (0 + 1) (TKScalar r))
-> FullShapeTK (TKR2 (0 + 1) (TKScalar r))
-> FullShapeTK
(TKProduct (TKR2 (0 + 1) (TKScalar r)) (TKR2 (0 + 1) (TKScalar r)))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct (IShR (0 + 1)
-> FullShapeTK (TKScalar r)
-> FullShapeTK (TKR2 (0 + 1) (TKScalar r))
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR (Int
sizeMnistGlyphInt Int -> ShR 0 Int -> IShR (0 + 1)
forall {n1 :: Nat} {i} (n :: Nat).
((n + 1 :: Nat) ~ (n1 :: Nat)) =>
i -> ShR n i -> ShR n1 i
:$: ShR 0 Int
forall (n :: Nat) i. ((n :: Nat) ~ (0 :: Nat)) => ShR n i
ZSR) FullShapeTK (TKScalar r)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar)
(IShR (0 + 1)
-> FullShapeTK (TKScalar r)
-> FullShapeTK (TKR2 (0 + 1) (TKScalar r))
forall (n :: Nat) (x :: TK).
IShR n -> FullShapeTK x -> FullShapeTK (TKR2 n x)
FTKR (Int
sizeMnistLabelInt Int -> ShR 0 Int -> IShR (0 + 1)
forall {n1 :: Nat} {i} (n :: Nat).
((n + 1 :: Nat) ~ (n1 :: Nat)) =>
i -> ShR n i -> ShR n1 i
:$: ShR 0 Int
forall (n :: Nat) i. ((n :: Nat) ~ (0 :: Nat)) => ShR n i
ZSR) FullShapeTK (TKScalar r)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar)
f :: ( MnistFcnnRanked2.ADFcnnMnist2Parameters
(AstTensor AstMethodLet FullSpan) r q
, ( AstTensor AstMethodLet FullSpan (TKR 1 r)
, AstTensor AstMethodLet FullSpan (TKR 1 r) ) )
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f :: (ADFcnnMnist2Parameters (AstTensor AstMethodLet FullSpan) r q,
(AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)),
AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f (((AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
hidden, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
bias), (AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar q))
hidden2, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
bias2), (AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
readout, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
biasr)), (AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
glyphR, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
labelR)) =
AstNoSimplify FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (s :: AstSpanType) (y :: TK).
AstNoSimplify s y -> AstTensor AstMethodLet s y
unAstNoSimplify
(AstNoSimplify FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r))
-> AstNoSimplify FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall a b. (a -> b) -> a -> b
$ (AstNoSimplify FullSpan (TKR2 1 (TKScalar r)),
AstNoSimplify FullSpan (TKR2 1 (TKScalar r)))
-> ADFcnnMnist2Parameters (AstNoSimplify FullSpan) r q
-> AstNoSimplify FullSpan (TKScalar r)
forall (target :: Target) r q.
(ADReady target, GoodScalar r, Differentiable r, GoodScalar q,
Differentiable q) =>
(target (TKR 1 r), target (TKR 1 r))
-> ADFcnnMnist2Parameters target r q -> target (TKScalar r)
afcnnMnistLoss2 (AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
-> AstNoSimplify FullSpan (TKR2 1 (TKScalar r))
forall (s :: AstSpanType) (y :: TK).
AstTensor AstMethodLet s y -> AstNoSimplify s y
AstNoSimplify AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
glyphR, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
-> AstNoSimplify FullSpan (TKR2 1 (TKScalar r))
forall (s :: AstSpanType) (y :: TK).
AstTensor AstMethodLet s y -> AstNoSimplify s y
AstNoSimplify AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
labelR)
( (AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
-> AstNoSimplify FullSpan (TKR2 2 (TKScalar r))
forall (s :: AstSpanType) (y :: TK).
AstTensor AstMethodLet s y -> AstNoSimplify s y
AstNoSimplify AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
hidden, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
-> AstNoSimplify FullSpan (TKR2 1 (TKScalar r))
forall (s :: AstSpanType) (y :: TK).
AstTensor AstMethodLet s y -> AstNoSimplify s y
AstNoSimplify AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
bias)
, (AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar q))
-> AstNoSimplify FullSpan (TKR2 2 (TKScalar q))
forall (s :: AstSpanType) (y :: TK).
AstTensor AstMethodLet s y -> AstNoSimplify s y
AstNoSimplify AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar q))
hidden2, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
-> AstNoSimplify FullSpan (TKR2 1 (TKScalar r))
forall (s :: AstSpanType) (y :: TK).
AstTensor AstMethodLet s y -> AstNoSimplify s y
AstNoSimplify AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
bias2)
, (AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
-> AstNoSimplify FullSpan (TKR2 2 (TKScalar r))
forall (s :: AstSpanType) (y :: TK).
AstTensor AstMethodLet s y -> AstNoSimplify s y
AstNoSimplify AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
readout, AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
-> AstNoSimplify FullSpan (TKR2 1 (TKScalar r))
forall (s :: AstSpanType) (y :: TK).
AstTensor AstMethodLet s y -> AstNoSimplify s y
AstNoSimplify AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))
biasr) )
artRaw :: AstArtifactRev
(X (ADFcnnMnist2Parameters (AstTensor AstMethodLet FullSpan) r q,
(AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)),
AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)))))
(TKScalar r)
artRaw = IncomingCotangentHandling
-> ((ADFcnnMnist2Parameters (AstTensor AstMethodLet FullSpan) r q,
(AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)),
AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))))
-> AstTensor AstMethodLet FullSpan (TKScalar r))
-> FullShapeTK
(X (ADFcnnMnist2Parameters (AstTensor AstMethodLet FullSpan) r q,
(AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)),
AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)))))
-> AstArtifactRev
(X (ADFcnnMnist2Parameters (AstTensor AstMethodLet FullSpan) r q,
(AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)),
AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)))))
(TKScalar r)
forall src (ztgt :: TK) tgt.
(AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
(tgt :: Type) ~ (AstTensor AstMethodLet FullSpan ztgt :: Type)) =>
IncomingCotangentHandling
-> (src -> tgt)
-> FullShapeTK (X src)
-> AstArtifactRev (X src) ztgt
revArtifactAdapt IncomingCotangentHandling
cotangentHandling (ADFcnnMnist2Parameters (AstTensor AstMethodLet FullSpan) r q,
(AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)),
AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r))))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f (FullShapeTK
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar q)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> FullShapeTK
(TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r)))
-> FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar q)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar q)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
ftk FullShapeTK (TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r)))
ftkData)
in (Concrete (XParams2 r q)
NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat 784 ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar q))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit, AstArtifactRev
(TKProduct
(XParams2 r q)
(TKProduct (TKR2 1 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKScalar r)
AstArtifactRev
(X (ADFcnnMnist2Parameters (AstTensor AstMethodLet FullSpan) r q,
(AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)),
AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar r)))))
(TKScalar r)
artRaw)
-- Specialisation requests for 'mnistTrainBench2VTOGradientX'.
--
-- Each pragma asks GHC to generate a copy of the function at one concrete
-- choice of the two scalar types that appear in @XParams2 r q@ and in the
-- returned 'AstArtifactRev'.  Reading the types off the pragmas, the three
-- requested instantiations are:
--
--   * @r = Double@, @q = Float@
--   * @r = Float@,  @q = Float@
--   * @r = Double@, @q = Double@
--
-- In all three the leading @Proxy@ argument carries the same type as the
-- second argument of @XParams2@, so it presumably exists only to fix @q@
-- (NOTE(review): confirm against the function's signature, which is not
-- shown next to these pragmas).
--
-- The remaining argument types (@IncomingCotangentHandling@, @Double@,
-- @StdGen@, @Int@, @Int@) are the same in every pragma; only the scalar
-- types vary.
{-# SPECIALIZE mnistTrainBench2VTOGradientX :: Proxy Float -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int -> ( Concrete (XParams2 Double Float), AstArtifactRev (TKProduct (XParams2 Double Float) (TKProduct (TKR2 1 (TKScalar Double)) (TKR2 1 (TKScalar Double)))) (TKScalar Double) ) #-}
{-# SPECIALIZE mnistTrainBench2VTOGradientX :: Proxy Float -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int -> ( Concrete (XParams2 Float Float), AstArtifactRev (TKProduct (XParams2 Float Float) (TKProduct (TKR2 1 (TKScalar Float)) (TKR2 1 (TKScalar Float)))) (TKScalar Float) ) #-}
{-# SPECIALIZE mnistTrainBench2VTOGradientX :: Proxy Double -> IncomingCotangentHandling -> Double -> StdGen -> Int -> Int -> ( Concrete (XParams2 Double Double), AstArtifactRev (TKProduct (XParams2 Double Double) (TKProduct (TKR2 1 (TKScalar Double)) (TKR2 1 (TKScalar Double)))) (TKScalar Double) ) #-}