{-# LANGUAGE OverloadedLists #-}
module TestMnistPP
( testTrees
) where
import Prelude
import GHC.Exts (IsList (..))
import System.Random
import Test.Tasty
import Test.Tasty.HUnit hiding (assert)
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Ranked.Shape
import HordeAd
import HordeAd.Core.Adaptor
import HordeAd.Core.AstEnv
import HordeAd.Core.AstFreshId
import HordeAd.Core.AstInterpret
import HordeAd.Core.Ops (treplicate)
import MnistCnnRanked2 qualified
import MnistCnnShaped2 qualified
import MnistData
import MnistFcnnRanked1 qualified
import MnistFcnnRanked2 (XParams2)
import MnistFcnnRanked2 qualified
import MnistRnnRanked2 (ADRnnMnistParameters)
import MnistRnnRanked2 qualified
-- | All test trees exported by this module, in the order they are run.
testTrees :: [TestTree]
testTrees = [ tensorMnistPPFCNNR
            , tensorMnistPPRNNR
            , tensorMnistCNNRPP
            ]
-- | The @X@ image of the concrete ('Concrete' target) parameter tuple of
-- "MnistFcnnRanked1", for the given two hidden-layer widths and scalar
-- type @r@.
type XParams widthHidden widthHidden2 r =
  X (MnistFcnnRanked1.ADFcnnMnist1Parameters
       Concrete widthHidden widthHidden2 r)
-- | Test group collecting the pretty-printing (\"PP\") and \"Ast\" test
-- cases for the VTO1 and VTO2 variants. Each entry pairs a display name
-- with the 'Assertion' that implements it.
tensorMnistPPFCNNR :: TestTree
tensorMnistPPFCNNR = testGroup "PP and Ast tests for Short Ranked MNIST"
  [ testCase "VTO1 PP Lin" testVTOPP
  , testCase "VTO1 Ast Lin" testVTOAst
  , testCase "VTO1 PP NonLin" testVTOPPNonLin
  , testCase "VTO1 Ast NonLin" testVTOAstNonLin
  , testCase "VTO2 PP Lin" testVT2OPP
  , testCase "VTO2 Ast Lin" testVT2OAst
  , testCase "VTO2 PP NonLin" testVT2OPPNonLin
  , testCase "VTO2 PP NonLin2" testVT2OPPNonLin2
  , testCase "VTO2 Ast NonLin2" testVT2OAstNonLin2
  , testCase "VTO2 PP NonLin3" testVT2OPPNonLin3
  , testCase "VTO2 Ast NonLin3" testVT2OAstNonLin3
  ]
-- | Fixed, deterministic parameter values for the width-3 / width-4
-- network used by the VTO1 tests.
--
-- The value is a triple of pairs; each pair holds a list of identical
-- vectors (built with 'replicate') and one further vector:
--
-- 1. three copies of the vector @[1 .. sizeMnistGlyphInt]@, then @[1, 2, 3]@;
-- 2. four copies of @[1, 2, 3]@, then @[1, 2, 3, 4]@;
-- 3. 'sizeMnistLabelInt' copies of @[1, 2, 3, 4]@, then @[1 .. 10]@
--    written out element by element.
--
-- The lists are turned into the sized list type of the parameter tuple
-- via 'fromList' from 'IsList'.
--
-- NOTE(review): presumably each pair is (weight rows, bias) of one layer,
-- and the list in the second pair appears to be fixed to @Float@ elements
-- independently of @r@ -- confirm against "MnistFcnnRanked1".
valsInitVTOPP :: (Num r, Enum r, Nested.PrimElt r)
              => MnistFcnnRanked1.ADFcnnMnist1Parameters Concrete 3 4 r
valsInitVTOPP =
  ( ( fromList (replicate 3 (Concrete
                             $ Nested.sfromListPrim
                                 (SNat @SizeMnistGlyph)
                                 [1 .. fromIntegral sizeMnistGlyphInt]))
    , Concrete $ Nested.sfromListPrim (SNat @3) [1, 2, 3] )
  , ( fromList (replicate 4 (Concrete $ Nested.sfromListPrim
                                         (SNat @3) [1, 2, 3]))
    , Concrete $ Nested.sfromListPrim (SNat @4) [1, 2, 3, 4] )
  , ( fromList (replicate sizeMnistLabelInt
                          (Concrete $ Nested.sfromListPrim
                                        (SNat @4) [1, 2, 3, 4]))
    , Concrete $ Nested.sfromListPrim (SNat @SizeMnistLabel)
                                      [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] ) )
testVTOPP :: Assertion
testVTOPP :: Assertion
testVTOPP = do
Assertion
resetVarCounter
let blackGlyph :: AstTensor AstMethodLet FullSpan (BuildTensorKind 784 (TKR 0 Float))
blackGlyph = SNat 784
-> SingletonTK (TKR 0 Float)
-> AstTensor AstMethodLet FullSpan (TKR 0 Float)
-> AstTensor
AstMethodLet FullSpan (BuildTensorKind 784 (TKR 0 Float))
forall (z :: TK) (k :: Natural).
ConvertTensor (AstTensor AstMethodLet FullSpan) =>
SNat k
-> SingletonTK z
-> AstTensor AstMethodLet FullSpan z
-> AstTensor AstMethodLet FullSpan (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Natural).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate (forall (n :: Natural). KnownNat n => SNat n
SNat @SizeMnistGlyph) SingletonTK (TKR 0 Float)
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK
(AstTensor AstMethodLet FullSpan (TKR 0 Float)
-> AstTensor
AstMethodLet FullSpan (BuildTensorKind 784 (TKR 0 Float)))
-> AstTensor AstMethodLet FullSpan (TKR 0 Float)
-> AstTensor
AstMethodLet FullSpan (BuildTensorKind 784 (TKR 0 Float))
forall a b. (a -> b) -> a -> b
$ AstTensor AstMethodLet PrimalSpan (TKR 0 Float)
-> AstTensor AstMethodLet FullSpan (TKR 0 Float)
forall (ms :: AstMethodOfSharing) (y :: TK).
AstTensor ms PrimalSpan y -> AstTensor ms FullSpan y
forall (s :: AstSpanType) (ms :: AstMethodOfSharing) (y :: TK).
AstSpan s =>
AstTensor ms PrimalSpan y -> AstTensor ms s y
fromPrimal (AstTensor AstMethodLet PrimalSpan (TKR 0 Float)
-> AstTensor AstMethodLet FullSpan (TKR 0 Float))
-> AstTensor AstMethodLet PrimalSpan (TKR 0 Float)
-> AstTensor AstMethodLet FullSpan (TKR 0 Float)
forall a b. (a -> b) -> a -> b
$ Ranked 0 Float -> AstTensor AstMethodLet PrimalSpan (TKR 0 Float)
forall r (target :: Target) (n :: Natural).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete (Ranked 0 Float -> AstTensor AstMethodLet PrimalSpan (TKR 0 Float))
-> Ranked 0 Float
-> AstTensor AstMethodLet PrimalSpan (TKR 0 Float)
forall a b. (a -> b) -> a -> b
$ Float -> Ranked 0 Float
forall a. Elt a => a -> Ranked 0 a
Nested.rscalar Float
7
afcnn2T :: MnistFcnnRanked1.ADFcnnMnist1Parameters
(AstTensor AstMethodLet FullSpan) 3 4 Float
-> AstTensor AstMethodLet FullSpan (TKR 1 Float)
afcnn2T :: ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) 3 4 Float
-> AstTensor AstMethodLet FullSpan (TKR 1 Float)
afcnn2T =
(forall (n :: Natural).
KnownNat n =>
AstTensor
AstMethodLet FullSpan (TKS ((':) @Natural n ('[] @Natural)) Float)
-> AstTensor
AstMethodLet FullSpan (TKS ((':) @Natural n ('[] @Natural)) Float))
-> (AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)
-> AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float))
-> SNat 3
-> SNat 4
-> AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistGlyph ('[] @Natural)) Float)
-> ADFcnnMnist1Parameters
(AstTensor AstMethodLet FullSpan) 3 4 Float
-> AstTensor AstMethodLet FullSpan (TKR 1 Float)
forall (target :: Target) r (widthHidden :: Natural)
(widthHidden2 :: Natural).
(ADReady target, GoodScalar r, Differentiable r) =>
(forall (n :: Natural).
KnownNat n =>
target (TKS ((':) @Natural n ('[] @Natural)) r)
-> target (TKS ((':) @Natural n ('[] @Natural)) r))
-> (target (TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) r)
-> target (TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) r))
-> SNat widthHidden
-> SNat widthHidden2
-> target (TKS ((':) @Natural SizeMnistGlyph ('[] @Natural)) r)
-> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
-> target (TKR 1 r)
MnistFcnnRanked1.afcnnMnist1 AstTensor
AstMethodLet FullSpan (TKS ((':) @Natural n ('[] @Natural)) Float)
-> AstTensor
AstMethodLet FullSpan (TKS ((':) @Natural n ('[] @Natural)) Float)
forall (n :: Natural).
KnownNat n =>
AstTensor
AstMethodLet FullSpan (TKS ((':) @Natural n ('[] @Natural)) Float)
-> AstTensor
AstMethodLet FullSpan (TKS ((':) @Natural n ('[] @Natural)) Float)
forall a. a -> a
id AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)
-> AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)
forall a. a -> a
id
(forall (n :: Natural). KnownNat n => SNat n
SNat @3) (forall (n :: Natural). KnownNat n => SNat n
SNat @4) (AstTensor
AstMethodLet
FullSpan
(TKR2
(Rank @Natural ((':) @Natural 784 ('[] @Natural)))
(TKScalar Float))
-> AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
forall (sh :: [Natural]) (x :: TK).
(KnownShS sh, KnownSTK x) =>
AstTensor AstMethodLet FullSpan (TKR2 (Rank @Natural sh) x)
-> AstTensor AstMethodLet FullSpan (TKS2 sh x)
forall (target :: Target) (sh :: [Natural]) (x :: TK).
(ConvertTensor target, KnownShS sh, KnownSTK x) =>
target (TKR2 (Rank @Natural sh) x) -> target (TKS2 sh x)
sfromR AstTensor AstMethodLet FullSpan (TKR 1 Float)
AstTensor
AstMethodLet
FullSpan
(TKR2
(Rank @Natural ((':) @Natural 784 ('[] @Natural)))
(TKScalar Float))
blackGlyph)
ftk :: FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
ftk = forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams 3 4 Float))
(forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete ((ListR
3
(Concrete
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))),
Concrete (TKS ((':) @Natural 3 ('[] @Natural)) Float)),
(ListR 4 (Concrete (TKS ((':) @Natural 3 ('[] @Natural)) Float)),
Concrete (TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
(ListR
SizeMnistLabel
(Concrete
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
ADFcnnMnist1Parameters Concrete 3 4 Float
forall r.
(Num r, Enum r, PrimElt r) =>
ADFcnnMnist1Parameters Concrete 3 4 r
valsInitVTOPP)
artifactRev :: AstArtifactRev
(X ((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float))))
(TKR 1 Float)
artifactRev = IncomingCotangentHandling
-> (((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
-> AstTensor AstMethodLet FullSpan (TKR 1 Float))
-> FullShapeTK
(X ((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float))))
-> AstArtifactRev
(X ((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float))))
(TKR 1 Float)
forall src (ztgt :: TK) tgt.
(AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
(tgt :: Type) ~ (AstTensor AstMethodLet FullSpan ztgt :: Type)) =>
IncomingCotangentHandling
-> (src -> tgt)
-> FullShapeTK (X src)
-> AstArtifactRev (X src) ztgt
revArtifactAdapt IncomingCotangentHandling
UseIncomingCotangent ((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
-> AstTensor AstMethodLet FullSpan (TKR 1 Float)
ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) 3 4 Float
-> AstTensor AstMethodLet FullSpan (TKR 1 Float)
afcnn2T FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
FullShapeTK
(X ((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float))))
ftk
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
(TKR 1 Float)
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPrimalPretty (AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
(TKR 1 Float)
-> AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
(TKR 1 Float)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifact AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
(TKR 1 Float)
artifactRev)
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\v1 -> rfromS (let v4 = sfromVector (fromList [sdot0 (sconcrete (sreplicate [784] 7.0)) (tproject1 (tproject1 (tproject1 (tproject1 v1)))), sdot0 (sconcrete (sreplicate [784] 7.0)) (tproject1 (tproject2 (tproject1 (tproject1 (tproject1 v1))))), sdot0 (sconcrete (sreplicate [784] 7.0)) (tproject1 (tproject2 (tproject2 (tproject1 (tproject1 (tproject1 v1))))))]) + tproject2 (tproject1 (tproject1 v1)) ; v5 = sfromVector (fromList [sdot0 (tproject1 (tproject1 (tproject2 (tproject1 v1)))) v4, sdot0 (tproject1 (tproject2 (tproject1 (tproject2 (tproject1 v1))))) v4, sdot0 (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1)))))) v4, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1))))))) v4]) + tproject2 (tproject2 (tproject1 v1)) in sfromVector (fromList [sdot0 (tproject1 (tproject1 (tproject2 v1))) v5, sdot0 (tproject1 (tproject2 (tproject1 (tproject2 v1)))) v5, sdot0 (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 v1))))) v5, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))) v5, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))) v5, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))) v5, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))) v5, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))))) v5, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))))) v5, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))))))) v5]) + tproject2 (tproject2 v1))"
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
(TKR 1 Float)
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPrimalPretty AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
(TKR 1 Float)
artifactRev
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\v1 -> let v4 = sfromVector (fromList [ssum @784 (sconcrete (sreplicate [784] 7.0) * tproject1 (tproject1 (tproject1 (tproject1 v1)))), ssum @784 (sconcrete (sreplicate [784] 7.0) * tproject1 (tproject2 (tproject1 (tproject1 (tproject1 v1))))), ssum @784 (sconcrete (sreplicate [784] 7.0) * tproject1 (tproject2 (tproject2 (tproject1 (tproject1 (tproject1 v1))))))]) + tproject2 (tproject1 (tproject1 v1)) ; v5 = sfromVector (fromList [ssum @3 (tproject1 (tproject1 (tproject2 (tproject1 v1))) * v4), ssum @3 (tproject1 (tproject2 (tproject1 (tproject2 (tproject1 v1)))) * v4), ssum @3 (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1))))) * v4), ssum @3 (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1)))))) * v4)]) + tproject2 (tproject2 (tproject1 v1)) in rfromS (sfromVector (fromList [ssum @4 (tproject1 (tproject1 (tproject2 v1)) * v5), ssum @4 (tproject1 (tproject2 (tproject1 (tproject2 v1))) * v5), ssum @4 (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 v1)))) * v5), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))) * v5), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))) * v5), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))) * v5), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))) * v5), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))) * v5), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))))) * v5), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))))) * v5)]) + tproject2 (tproject2 v1))"
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
(TKR 1 Float)
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPretty AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
(TKR 1 Float)
artifactRev
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\dret v1 -> let v4 = sfromVector (fromList [ssum @784 (sconcrete (sreplicate [784] 7.0) * tproject1 (tproject1 (tproject1 (tproject1 v1)))), ssum @784 (sconcrete (sreplicate [784] 7.0) * tproject1 (tproject2 (tproject1 (tproject1 (tproject1 v1))))), ssum @784 (sconcrete (sreplicate [784] 7.0) * tproject1 (tproject2 (tproject2 (tproject1 (tproject1 (tproject1 v1))))))]) + tproject2 (tproject1 (tproject1 v1)) ; v5 = sfromVector (fromList [ssum @3 (tproject1 (tproject1 (tproject2 (tproject1 v1))) * v4), ssum @3 (tproject1 (tproject2 (tproject1 (tproject2 (tproject1 v1)))) * v4), ssum @3 (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1))))) * v4), ssum @3 (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1)))))) * v4)]) + tproject2 (tproject2 (tproject1 v1)) ; v7 = sreplicate @4 (sfromR dret !$ [9]) ; v8 = sreplicate @4 (sfromR dret !$ [8]) ; v9 = sreplicate @4 (sfromR dret !$ [7]) ; v10 = sreplicate @4 (sfromR dret !$ [6]) ; v11 = sreplicate @4 (sfromR dret !$ [5]) ; v12 = sreplicate @4 (sfromR dret !$ [4]) ; v13 = sreplicate @4 (sfromR dret !$ [3]) ; v14 = sreplicate @4 (sfromR dret !$ [2]) ; v15 = sreplicate @4 (sfromR dret !$ [1]) ; v16 = sreplicate @4 (sfromR dret !$ [0]) ; v17 = tproject1 (tproject1 (tproject2 v1)) * v16 + (tproject1 (tproject2 (tproject1 (tproject2 v1))) * v15 + (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 v1)))) * v14 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))) * v13 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))) * v12 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))) * v11 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))) * v10 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))) * v9 + (tproject1 (tproject2 (tproject2 (tproject2 
(tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))))) * v8 + tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))))) * v7)))))))) ; v18 = sreplicate @3 (v17 !$ [3]) ; v19 = sreplicate @3 (v17 !$ [2]) ; v20 = sreplicate @3 (v17 !$ [1]) ; v21 = sreplicate @3 (v17 !$ [0]) ; v22 = tproject1 (tproject1 (tproject2 (tproject1 v1))) * v21 + (tproject1 (tproject2 (tproject1 (tproject2 (tproject1 v1)))) * v20 + (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1))))) * v19 + tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1)))))) * v18)) in tpair (tpair (tpair (tpair (sconcrete (sreplicate [784] 7.0) * sreplicate @784 (v22 !$ [0])) (tpair (sconcrete (sreplicate [784] 7.0) * sreplicate @784 (v22 !$ [1])) (tpair (sconcrete (sreplicate [784] 7.0) * sreplicate @784 (v22 !$ [2])) Z1))) v22) (tpair (tpair (v4 * v21) (tpair (v4 * v20) (tpair (v4 * v19) (tpair (v4 * v18) Z1)))) v17)) (tpair (tpair (v5 * v16) (tpair (v5 * v15) (tpair (v5 * v14) (tpair (v5 * v13) (tpair (v5 * v12) (tpair (v5 * v11) (tpair (v5 * v10) (tpair (v5 * v9) (tpair (v5 * v8) (tpair (v5 * v7) Z1)))))))))) (sfromR dret))"
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
(TKR 1 Float)
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPretty (AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
(TKR 1 Float)
-> AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
(TKR 1 Float)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifact AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
(TKR 1 Float)
artifactRev)
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\dret v1 -> let v4 = sfromVector (fromList [sdot0 (sconcrete (sreplicate [784] 7.0)) (tproject1 (tproject1 (tproject1 (tproject1 v1)))), sdot0 (sconcrete (sreplicate [784] 7.0)) (tproject1 (tproject2 (tproject1 (tproject1 (tproject1 v1))))), sdot0 (sconcrete (sreplicate [784] 7.0)) (tproject1 (tproject2 (tproject2 (tproject1 (tproject1 (tproject1 v1))))))]) + tproject2 (tproject1 (tproject1 v1)) ; v5 = sfromVector (fromList [sdot0 (tproject1 (tproject1 (tproject2 (tproject1 v1)))) v4, sdot0 (tproject1 (tproject2 (tproject1 (tproject2 (tproject1 v1))))) v4, sdot0 (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1)))))) v4, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1))))))) v4]) + tproject2 (tproject2 (tproject1 v1)) ; x7 = sfromR dret !$ [9] ; x8 = sfromR dret !$ [8] ; x9 = sfromR dret !$ [7] ; x10 = sfromR dret !$ [6] ; x11 = sfromR dret !$ [5] ; x12 = sfromR dret !$ [4] ; x13 = sfromR dret !$ [3] ; x14 = sfromR dret !$ [2] ; x15 = sfromR dret !$ [1] ; x16 = sfromR dret !$ [0] ; v17 = tproject1 (tproject1 (tproject2 v1)) * sreplicate @4 x16 + (tproject1 (tproject2 (tproject1 (tproject2 v1))) * sreplicate @4 x15 + (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 v1)))) * sreplicate @4 x14 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))) * sreplicate @4 x13 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))) * sreplicate @4 x12 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))) * sreplicate @4 x11 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))) * sreplicate @4 x10 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))) * sreplicate @4 x9 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 
v1)))))))))) * sreplicate @4 x8 + tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))))) * sreplicate @4 x7)))))))) ; x18 = v17 !$ [3] ; x19 = v17 !$ [2] ; x20 = v17 !$ [1] ; x21 = v17 !$ [0] ; v22 = tproject1 (tproject1 (tproject2 (tproject1 v1))) * sreplicate @3 x21 + (tproject1 (tproject2 (tproject1 (tproject2 (tproject1 v1)))) * sreplicate @3 x20 + (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1))))) * sreplicate @3 x19 + tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1)))))) * sreplicate @3 x18)) in tpair (tpair (tpair (tpair (sconcrete (sreplicate [784] 7.0) * sreplicate @784 (v22 !$ [0])) (tpair (sconcrete (sreplicate [784] 7.0) * sreplicate @784 (v22 !$ [1])) (tpair (sconcrete (sreplicate [784] 7.0) * sreplicate @784 (v22 !$ [2])) Z1))) v22) (tpair (tpair (v4 * sreplicate @3 x21) (tpair (v4 * sreplicate @3 x20) (tpair (v4 * sreplicate @3 x19) (tpair (v4 * sreplicate @3 x18) Z1)))) v17)) (tpair (tpair (v5 * sreplicate @4 x16) (tpair (v5 * sreplicate @4 x15) (tpair (v5 * sreplicate @4 x14) (tpair (v5 * sreplicate @4 x13) (tpair (v5 * sreplicate @4 x12) (tpair (v5 * sreplicate @4 x11) (tpair (v5 * sreplicate @4 x10) (tpair (v5 * sreplicate @4 x9) (tpair (v5 * sreplicate @4 x8) (tpair (v5 * sreplicate @4 x7) Z1)))))))))) (sfromR dret))"
testVTOAst :: Assertion
testVTOAst :: Assertion
testVTOAst = do
let ftk :: FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
ftk = forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams 3 4 Float))
(forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete ((ListR
3
(Concrete
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))),
Concrete (TKS ((':) @Natural 3 ('[] @Natural)) Float)),
(ListR 4 (Concrete (TKS ((':) @Natural 3 ('[] @Natural)) Float)),
Concrete (TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
(ListR
SizeMnistLabel
(Concrete
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
ADFcnnMnist1Parameters Concrete 3 4 Float
forall r.
(Num r, Enum r, PrimElt r) =>
ADFcnnMnist1Parameters Concrete 3 4 r
valsInitVTOPP)
varName :: AstVarName
s
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
varName = FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
-> Maybe (Int64, Int64)
-> AstVarId
-> AstVarName
s
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
forall (s :: AstSpanType) (y :: TK).
FullShapeTK y -> Maybe (Int64, Int64) -> AstVarId -> AstVarName s y
mkAstVarName FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
ftk Maybe (Int64, Int64)
forall a. Maybe a
Nothing (AstVarId
-> AstVarName
s
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float))))
-> (Int -> AstVarId)
-> Int
-> AstVarName
s
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> AstVarId
intToAstVarId (Int
-> AstVarName
s
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float))))
-> Int
-> AstVarName
s
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
forall a b. (a -> b) -> a -> b
$ Int
100000000
var :: AstTensor AstMethodLet FullSpan (XParams 3 4 Float)
var :: AstTensor AstMethodLet FullSpan (XParams 3 4 Float)
var = AstVarName
FullSpan
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
-> AstTensor
AstMethodLet
FullSpan
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
forall (b :: AstSpanType) (c :: TK) (a :: AstMethodOfSharing).
AstVarName b c -> AstTensor a b c
AstVar AstVarName
FullSpan
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
forall {s :: AstSpanType}.
AstVarName
s
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
varName
vals :: Concrete
(X ((ListR
3
(Concrete
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))),
Concrete (TKS ((':) @Natural 3 ('[] @Natural)) Float)),
(ListR 4 (Concrete (TKS ((':) @Natural 3 ('[] @Natural)) Float)),
Concrete (TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
(ListR
SizeMnistLabel
(Concrete
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float))))
vals = forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete ((ListR
3
(Concrete
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))),
Concrete (TKS ((':) @Natural 3 ('[] @Natural)) Float)),
(ListR 4 (Concrete (TKS ((':) @Natural 3 ('[] @Natural)) Float)),
Concrete (TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
(ListR
SizeMnistLabel
(Concrete
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
ADFcnnMnist1Parameters Concrete 3 4 Float
forall r.
(Num r, Enum r, PrimElt r) =>
ADFcnnMnist1Parameters Concrete 3 4 r
valsInitVTOPP
env :: AstEnv Concrete
env = AstVarName
(ZonkAny @AstSpanType 5)
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
-> AstEnv Concrete
-> AstEnv Concrete
forall (target :: Target) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
(ZonkAny @AstSpanType 5)
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
forall {s :: AstSpanType}.
AstVarName
s
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
varName Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))
TKUnit)))
(TKS ((':) @Natural 3 ('[] @Natural)) Float))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))))
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
(TKProduct
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
vals AstEnv Concrete
forall (target :: Target). AstEnv target
emptyEnv
blackGlyph :: Concrete (BuildTensorKind 784 (TKR 0 Float))
blackGlyph = SNat 784
-> SingletonTK (TKR 0 Float)
-> Concrete (TKR 0 Float)
-> Concrete (BuildTensorKind 784 (TKR 0 Float))
forall (z :: TK) (k :: Natural).
ConvertTensor Concrete =>
SNat k
-> SingletonTK z -> Concrete z -> Concrete (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Natural).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate (forall (n :: Natural). KnownNat n => SNat n
SNat @SizeMnistGlyph) SingletonTK (TKR 0 Float)
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK (Concrete (TKR 0 Float)
-> Concrete (BuildTensorKind 784 (TKR 0 Float)))
-> Concrete (TKR 0 Float)
-> Concrete (BuildTensorKind 784 (TKR 0 Float))
forall a b. (a -> b) -> a -> b
$ Float -> Concrete (TKR 0 Float)
forall r (target :: Target).
(GoodScalar r, BaseTensor target) =>
r -> target (TKR 0 r)
rscalar Float
7
afcnn2 :: ADReady f
=> MnistFcnnRanked1.ADFcnnMnist1Parameters f 3 4 Float
-> f (TKR 1 Float)
afcnn2 :: forall (f :: Target).
ADReady f =>
ADFcnnMnist1Parameters f 3 4 Float -> f (TKR 1 Float)
afcnn2 = (forall (n :: Natural).
KnownNat n =>
f (TKS ((':) @Natural n ('[] @Natural)) Float)
-> f (TKS ((':) @Natural n ('[] @Natural)) Float))
-> (f (TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)
-> f (TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float))
-> SNat 3
-> SNat 4
-> f (TKS ((':) @Natural SizeMnistGlyph ('[] @Natural)) Float)
-> ADFcnnMnist1Parameters f 3 4 Float
-> f (TKR 1 Float)
forall (target :: Target) r (widthHidden :: Natural)
(widthHidden2 :: Natural).
(ADReady target, GoodScalar r, Differentiable r) =>
(forall (n :: Natural).
KnownNat n =>
target (TKS ((':) @Natural n ('[] @Natural)) r)
-> target (TKS ((':) @Natural n ('[] @Natural)) r))
-> (target (TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) r)
-> target (TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) r))
-> SNat widthHidden
-> SNat widthHidden2
-> target (TKS ((':) @Natural SizeMnistGlyph ('[] @Natural)) r)
-> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
-> target (TKR 1 r)
MnistFcnnRanked1.afcnnMnist1
f (TKS ((':) @Natural n ('[] @Natural)) Float)
-> f (TKS ((':) @Natural n ('[] @Natural)) Float)
forall (n :: Natural).
KnownNat n =>
f (TKS ((':) @Natural n ('[] @Natural)) Float)
-> f (TKS ((':) @Natural n ('[] @Natural)) Float)
forall a. a -> a
id f (TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)
-> f (TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)
forall a. a -> a
id (forall (n :: Natural). KnownNat n => SNat n
SNat @3) (forall (n :: Natural). KnownNat n => SNat n
SNat @4)
(f (TKR2
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural)))
(TKScalar Float))
-> f (TKS ((':) @Natural SizeMnistGlyph ('[] @Natural)) Float)
forall (sh :: [Natural]) (x :: TK).
(KnownShS sh, KnownSTK x) =>
f (TKR2 (Rank @Natural sh) x) -> f (TKS2 sh x)
forall (target :: Target) (sh :: [Natural]) (x :: TK).
(ConvertTensor target, KnownShS sh, KnownSTK x) =>
target (TKR2 (Rank @Natural sh) x) -> target (TKS2 sh x)
sfromR (f (TKR2
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural)))
(TKScalar Float))
-> f (TKS ((':) @Natural SizeMnistGlyph ('[] @Natural)) Float))
-> f (TKR2
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural)))
(TKScalar Float))
-> f (TKS ((':) @Natural SizeMnistGlyph ('[] @Natural)) Float)
forall a b. (a -> b) -> a -> b
$ Ranked
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural))) Float
-> f (TKR2
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural)))
(TKScalar Float))
forall r (target :: Target) (n :: Natural).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete (Ranked
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural))) Float
-> f (TKR2
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural)))
(TKScalar Float)))
-> Ranked
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural))) Float
-> f (TKR2
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural)))
(TKScalar Float))
forall a b. (a -> b) -> a -> b
$ Concrete (TKR 1 Float) -> RepConcrete (TKR 1 Float)
forall (y :: TK). Concrete y -> RepConcrete y
unConcrete Concrete (TKR 1 Float)
blackGlyph)
afcnn1 :: AstTensor AstMethodLet FullSpan (TKR 1 Float)
afcnn1 = ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) 3 4 Float
-> AstTensor AstMethodLet FullSpan (TKR 1 Float)
forall (f :: Target).
ADReady f =>
ADFcnnMnist1Parameters f 3 4 Float -> f (TKR 1 Float)
afcnn2 (ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) 3 4 Float
-> AstTensor AstMethodLet FullSpan (TKR 1 Float))
-> ADFcnnMnist1Parameters
(AstTensor AstMethodLet FullSpan) 3 4 Float
-> AstTensor AstMethodLet FullSpan (TKR 1 Float)
forall a b. (a -> b) -> a -> b
$ AstTensor
AstMethodLet
FullSpan
(X ((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float))))
-> ((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float)))
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget AstTensor
AstMethodLet
FullSpan
(X ((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 4 ('[] @Natural)) (TKScalar Float))),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Float))))
AstTensor AstMethodLet FullSpan (XParams 3 4 Float)
var
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull @Concrete AstEnv Concrete
env AstTensor AstMethodLet FullSpan (TKR 1 Float)
afcnn1
Concrete (TKR 1 Float) -> Concrete (TKR 1 Float) -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= ADFcnnMnist1Parameters Concrete 3 4 Float -> Concrete (TKR 1 Float)
forall (f :: Target).
ADReady f =>
ADFcnnMnist1Parameters f 3 4 Float -> f (TKR 1 Float)
afcnn2 ADFcnnMnist1Parameters Concrete 3 4 Float
forall r.
(Num r, Enum r, PrimElt r) =>
ADFcnnMnist1Parameters Concrete 3 4 r
valsInitVTOPP
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull @Concrete AstEnv Concrete
env
(forall (z :: TK) (s :: AstSpanType).
AstSpan s =>
AstTensor AstMethodLet s z -> AstTensor AstMethodLet s z
simplifyInline @(TKR 1 Float) AstTensor AstMethodLet FullSpan (TKR 1 Float)
afcnn1)
Concrete (TKR 1 Float) -> Concrete (TKR 1 Float) -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= ADFcnnMnist1Parameters Concrete 3 4 Float -> Concrete (TKR 1 Float)
forall (f :: Target).
ADReady f =>
ADFcnnMnist1Parameters f 3 4 Float -> f (TKR 1 Float)
afcnn2 ADFcnnMnist1Parameters Concrete 3 4 Float
forall r.
(Num r, Enum r, PrimElt r) =>
ADFcnnMnist1Parameters Concrete 3 4 r
valsInitVTOPP
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull @Concrete AstEnv Concrete
env
(forall (z :: TK) (s :: AstSpanType).
AstSpan s =>
AstTensor AstMethodLet s z -> AstTensor AstMethodLet s z
simplifyInlineContract @(TKR 1 Float) AstTensor AstMethodLet FullSpan (TKR 1 Float)
afcnn1)
Concrete (TKR 1 Float) -> Concrete (TKR 1 Float) -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= ADFcnnMnist1Parameters Concrete 3 4 Float -> Concrete (TKR 1 Float)
forall (f :: Target).
ADReady f =>
ADFcnnMnist1Parameters f 3 4 Float -> f (TKR 1 Float)
afcnn2 ADFcnnMnist1Parameters Concrete 3 4 Float
forall r.
(Num r, Enum r, PrimElt r) =>
ADFcnnMnist1Parameters Concrete 3 4 r
valsInitVTOPP
testVTOPPNonLin :: Assertion
testVTOPPNonLin :: Assertion
testVTOPPNonLin = do
Assertion
resetVarCounter
let blackGlyph :: AstTensor
AstMethodLet FullSpan (BuildTensorKind 784 (TKR 0 Double))
blackGlyph = SNat 784
-> SingletonTK (TKR 0 Double)
-> AstTensor AstMethodLet FullSpan (TKR 0 Double)
-> AstTensor
AstMethodLet FullSpan (BuildTensorKind 784 (TKR 0 Double))
forall (z :: TK) (k :: Natural).
ConvertTensor (AstTensor AstMethodLet FullSpan) =>
SNat k
-> SingletonTK z
-> AstTensor AstMethodLet FullSpan z
-> AstTensor AstMethodLet FullSpan (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Natural).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate (forall (n :: Natural). KnownNat n => SNat n
SNat @SizeMnistGlyph) SingletonTK (TKR 0 Double)
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK
(AstTensor AstMethodLet FullSpan (TKR 0 Double)
-> AstTensor
AstMethodLet FullSpan (BuildTensorKind 784 (TKR 0 Double)))
-> AstTensor AstMethodLet FullSpan (TKR 0 Double)
-> AstTensor
AstMethodLet FullSpan (BuildTensorKind 784 (TKR 0 Double))
forall a b. (a -> b) -> a -> b
$ AstTensor AstMethodLet PrimalSpan (TKR 0 Double)
-> AstTensor AstMethodLet FullSpan (TKR 0 Double)
forall (ms :: AstMethodOfSharing) (y :: TK).
AstTensor ms PrimalSpan y -> AstTensor ms FullSpan y
forall (s :: AstSpanType) (ms :: AstMethodOfSharing) (y :: TK).
AstSpan s =>
AstTensor ms PrimalSpan y -> AstTensor ms s y
fromPrimal (AstTensor AstMethodLet PrimalSpan (TKR 0 Double)
-> AstTensor AstMethodLet FullSpan (TKR 0 Double))
-> AstTensor AstMethodLet PrimalSpan (TKR 0 Double)
-> AstTensor AstMethodLet FullSpan (TKR 0 Double)
forall a b. (a -> b) -> a -> b
$ Ranked 0 Double -> AstTensor AstMethodLet PrimalSpan (TKR 0 Double)
forall r (target :: Target) (n :: Natural).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete (Ranked 0 Double
-> AstTensor AstMethodLet PrimalSpan (TKR 0 Double))
-> Ranked 0 Double
-> AstTensor AstMethodLet PrimalSpan (TKR 0 Double)
forall a b. (a -> b) -> a -> b
$ Double -> Ranked 0 Double
forall a. Elt a => a -> Ranked 0 a
Nested.rscalar Double
7
afcnn2TnonLin :: MnistFcnnRanked1.ADFcnnMnist1Parameters
(AstTensor AstMethodLet FullSpan) 3 4 Double
-> AstTensor AstMethodLet FullSpan (TKR 1 Double)
afcnn2TnonLin :: ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) 3 4 Double
-> AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
afcnn2TnonLin =
(forall (n :: Natural).
KnownNat n =>
AstTensor
AstMethodLet FullSpan (TKS ((':) @Natural n ('[] @Natural)) Double)
-> AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural n ('[] @Natural)) Double))
-> (AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)
-> AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))
-> SNat 3
-> SNat 4
-> AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistGlyph ('[] @Natural)) Double)
-> ADFcnnMnist1Parameters
(AstTensor AstMethodLet FullSpan) 3 4 Double
-> AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
forall (target :: Target) r (widthHidden :: Natural)
(widthHidden2 :: Natural).
(ADReady target, GoodScalar r, Differentiable r) =>
(forall (n :: Natural).
KnownNat n =>
target (TKS ((':) @Natural n ('[] @Natural)) r)
-> target (TKS ((':) @Natural n ('[] @Natural)) r))
-> (target (TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) r)
-> target (TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) r))
-> SNat widthHidden
-> SNat widthHidden2
-> target (TKS ((':) @Natural SizeMnistGlyph ('[] @Natural)) r)
-> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
-> target (TKR 1 r)
MnistFcnnRanked1.afcnnMnist1 AstTensor
AstMethodLet FullSpan (TKS ((':) @Natural n ('[] @Natural)) Double)
-> AstTensor
AstMethodLet FullSpan (TKS ((':) @Natural n ('[] @Natural)) Double)
forall (n :: Natural).
KnownNat n =>
AstTensor
AstMethodLet FullSpan (TKS ((':) @Natural n ('[] @Natural)) Double)
-> AstTensor
AstMethodLet FullSpan (TKS ((':) @Natural n ('[] @Natural)) Double)
forall (target :: Target) r (sh :: [Natural]).
(BaseTensor target, LetTensor target, BaseTensor (PrimalOf target),
KnownShS sh, GoodScalar r, Differentiable r) =>
target (TKS sh r) -> target (TKS sh r)
logisticS AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)
-> AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)
forall (target :: Target) (sh :: [Natural]) r.
(KnownShS sh, BaseTensor target, LetTensor target, GoodScalar r,
Differentiable r) =>
target (TKS sh r) -> target (TKS sh r)
softMax1S
(forall (n :: Natural). KnownNat n => SNat n
SNat @3) (forall (n :: Natural). KnownNat n => SNat n
SNat @4) (AstTensor
AstMethodLet
FullSpan
(TKR2
(Rank @Natural ((':) @Natural 784 ('[] @Natural)))
(TKScalar Double))
-> AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
forall (sh :: [Natural]) (x :: TK).
(KnownShS sh, KnownSTK x) =>
AstTensor AstMethodLet FullSpan (TKR2 (Rank @Natural sh) x)
-> AstTensor AstMethodLet FullSpan (TKS2 sh x)
forall (target :: Target) (sh :: [Natural]) (x :: TK).
(ConvertTensor target, KnownShS sh, KnownSTK x) =>
target (TKR2 (Rank @Natural sh) x) -> target (TKS2 sh x)
sfromR AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
AstTensor
AstMethodLet
FullSpan
(TKR2
(Rank @Natural ((':) @Natural 784 ('[] @Natural)))
(TKScalar Double))
blackGlyph)
ftk :: FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
ftk = forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams 3 4 Double))
(forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete ((ListR
3
(Concrete
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))),
Concrete
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double))),
(ListR 4 (Concrete (TKS ((':) @Natural 3 ('[] @Natural)) Float)),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(ListR
SizeMnistLabel
(Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
ADFcnnMnist1Parameters Concrete 3 4 Double
forall r.
(Num r, Enum r, PrimElt r) =>
ADFcnnMnist1Parameters Concrete 3 4 r
valsInitVTOPP)
artifactRevnonLin :: AstArtifactRev
(X ((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double))),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
(TKR2 1 (TKScalar Double))
artifactRevnonLin =
IncomingCotangentHandling
-> (((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double))),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
-> AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double)))
-> FullShapeTK
(X ((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double))),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
-> AstArtifactRev
(X ((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double))),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
(TKR2 1 (TKScalar Double))
forall src (ztgt :: TK) tgt.
(AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
(tgt :: Type) ~ (AstTensor AstMethodLet FullSpan ztgt :: Type)) =>
IncomingCotangentHandling
-> (src -> tgt)
-> FullShapeTK (X src)
-> AstArtifactRev (X src) ztgt
revArtifactAdapt IncomingCotangentHandling
UseIncomingCotangent ((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double))),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
-> AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) 3 4 Double
-> AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
afcnn2TnonLin FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
FullShapeTK
(X ((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double))),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
ftk
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
(TKR2 1 (TKScalar Double))
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPrimalPretty (AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
(TKR2 1 (TKScalar Double))
-> AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
(TKR2 1 (TKScalar Double))
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifact AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
(TKR2 1 (TKScalar Double))
artifactRevnonLin)
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\v1 -> rfromS (let v15 = scast (recip (sconcrete (sreplicate [3] 1.0) + exp (negate (sfromVector (fromList [sdot0 (sconcrete (sreplicate [784] 7.0)) (tproject1 (tproject1 (tproject1 (tproject1 v1)))), sdot0 (sconcrete (sreplicate [784] 7.0)) (tproject1 (tproject2 (tproject1 (tproject1 (tproject1 v1))))), sdot0 (sconcrete (sreplicate [784] 7.0)) (tproject1 (tproject2 (tproject2 (tproject1 (tproject1 (tproject1 v1))))))])) + negate (tproject2 (tproject1 (tproject1 v1)))))) ; v19 = recip (sconcrete (sreplicate [4] 1.0) + exp (negate (scast (sfromVector (fromList [sdot0 (tproject1 (tproject1 (tproject2 (tproject1 v1)))) v15, sdot0 (tproject1 (tproject2 (tproject1 (tproject2 (tproject1 v1))))) v15, sdot0 (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1)))))) v15, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1))))))) v15]))) + negate (tproject2 (tproject2 (tproject1 v1))))) ; v22 = exp (sfromVector (fromList [sdot0 (tproject1 (tproject1 (tproject2 v1))) v19, sdot0 (tproject1 (tproject2 (tproject1 (tproject2 v1)))) v19, sdot0 (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 v1))))) v19, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))) v19, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))) v19, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))) v19, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))) v19, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))))) v19, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))))) v19, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 
v1)))))))))))) v19]) + tproject2 (tproject2 v1)) in sreplicate @10 (recip (ssum0 v22)) * v22)"
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
(TKR2 1 (TKScalar Double))
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPrimalPretty AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
(TKR2 1 (TKScalar Double))
artifactRevnonLin
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\v1 -> let v9 = sfromVector (fromList [ssum @784 (sconcrete (sreplicate [784] 7.0) * tproject1 (tproject1 (tproject1 (tproject1 v1)))), ssum @784 (sconcrete (sreplicate [784] 7.0) * tproject1 (tproject2 (tproject1 (tproject1 (tproject1 v1))))), ssum @784 (sconcrete (sreplicate [784] 7.0) * tproject1 (tproject2 (tproject2 (tproject1 (tproject1 (tproject1 v1))))))]) + tproject2 (tproject1 (tproject1 v1)) ; v10 = exp (negate v9) ; v11 = sconcrete (sreplicate [3] 1.0) + v10 ; v12 = recip v11 ; v15 = scast v12 ; v16 = scast (sfromVector (fromList [ssum @3 (tproject1 (tproject1 (tproject2 (tproject1 v1))) * v15), ssum @3 (tproject1 (tproject2 (tproject1 (tproject2 (tproject1 v1)))) * v15), ssum @3 (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1))))) * v15), ssum @3 (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1)))))) * v15)])) + tproject2 (tproject2 (tproject1 v1)) ; v17 = exp (negate v16) ; v18 = sconcrete (sreplicate [4] 1.0) + v17 ; v19 = recip v18 ; v22 = exp (sfromVector (fromList [ssum @4 (tproject1 (tproject1 (tproject2 v1)) * v19), ssum @4 (tproject1 (tproject2 (tproject1 (tproject2 v1))) * v19), ssum @4 (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 v1)))) * v19), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))) * v19), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))) * v19), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))) * v19), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))) * v19), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))) * v19), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))))) * v19), ssum @4 (tproject1 (tproject2 
(tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))))) * v19)]) + tproject2 (tproject2 v1)) ; x23 = ssum @10 v22 ; v24 = sreplicate @10 (recip x23) in rfromS (v24 * v22)"
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
(TKR2 1 (TKScalar Double))
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPretty AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
(TKR2 1 (TKScalar Double))
artifactRevnonLin
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\dret v1 -> let v9 = sfromVector (fromList [ssum @784 (sconcrete (sreplicate [784] 7.0) * tproject1 (tproject1 (tproject1 (tproject1 v1)))), ssum @784 (sconcrete (sreplicate [784] 7.0) * tproject1 (tproject2 (tproject1 (tproject1 (tproject1 v1))))), ssum @784 (sconcrete (sreplicate [784] 7.0) * tproject1 (tproject2 (tproject2 (tproject1 (tproject1 (tproject1 v1))))))]) + tproject2 (tproject1 (tproject1 v1)) ; v10 = exp (negate v9) ; v11 = sconcrete (sreplicate [3] 1.0) + v10 ; v12 = recip v11 ; v13 = sconcrete (sreplicate [3] 1.0) + negate v12 ; v14 = v12 * v13 ; v15 = scast v12 ; v16 = scast (sfromVector (fromList [ssum @3 (tproject1 (tproject1 (tproject2 (tproject1 v1))) * v15), ssum @3 (tproject1 (tproject2 (tproject1 (tproject2 (tproject1 v1)))) * v15), ssum @3 (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1))))) * v15), ssum @3 (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1)))))) * v15)])) + tproject2 (tproject2 (tproject1 v1)) ; v17 = exp (negate v16) ; v18 = sconcrete (sreplicate [4] 1.0) + v17 ; v19 = recip v18 ; v20 = sconcrete (sreplicate [4] 1.0) + negate v19 ; v21 = v19 * v20 ; v22 = exp (sfromVector (fromList [ssum @4 (tproject1 (tproject1 (tproject2 v1)) * v19), ssum @4 (tproject1 (tproject2 (tproject1 (tproject2 v1))) * v19), ssum @4 (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 v1)))) * v19), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))) * v19), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))) * v19), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))) * v19), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))) * v19), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))) * v19), ssum @4 (tproject1 (tproject2 
(tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))))) * v19), ssum @4 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))))) * v19)]) + tproject2 (tproject2 v1)) ; x23 = ssum @10 v22 ; v24 = sreplicate @10 (recip x23) ; v26 = v22 * (sreplicate @10 (negate (recip (x23 * x23)) * ssum @10 (v22 * sfromR dret)) + v24 * sfromR dret) ; v27 = sreplicate @4 (v26 !$ [9]) ; v28 = sreplicate @4 (v26 !$ [8]) ; v29 = sreplicate @4 (v26 !$ [7]) ; v30 = sreplicate @4 (v26 !$ [6]) ; v31 = sreplicate @4 (v26 !$ [5]) ; v32 = sreplicate @4 (v26 !$ [4]) ; v33 = sreplicate @4 (v26 !$ [3]) ; v34 = sreplicate @4 (v26 !$ [2]) ; v35 = sreplicate @4 (v26 !$ [1]) ; v36 = sreplicate @4 (v26 !$ [0]) ; v37 = v21 * (tproject1 (tproject1 (tproject2 v1)) * v36 + (tproject1 (tproject2 (tproject1 (tproject2 v1))) * v35 + (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 v1)))) * v34 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))) * v33 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))) * v32 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))) * v31 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))) * v30 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))) * v29 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))))) * v28 + tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))))) * v27))))))))) ; v38 = scast v37 ; v39 = sreplicate @3 (v38 !$ [3]) ; v40 = sreplicate @3 (v38 !$ [2]) ; v41 = sreplicate @3 (v38 !$ [1]) ; v42 = sreplicate @3 (v38 
!$ [0]) ; v43 = v14 * scast (tproject1 (tproject1 (tproject2 (tproject1 v1))) * v42 + (tproject1 (tproject2 (tproject1 (tproject2 (tproject1 v1)))) * v41 + (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1))))) * v40 + tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1)))))) * v39))) in tpair (tpair (tpair (tpair (sconcrete (sreplicate [784] 7.0) * sreplicate @784 (v43 !$ [0])) (tpair (sconcrete (sreplicate [784] 7.0) * sreplicate @784 (v43 !$ [1])) (tpair (sconcrete (sreplicate [784] 7.0) * sreplicate @784 (v43 !$ [2])) Z1))) v43) (tpair (tpair (v15 * v42) (tpair (v15 * v41) (tpair (v15 * v40) (tpair (v15 * v39) Z1)))) v37)) (tpair (tpair (v19 * v36) (tpair (v19 * v35) (tpair (v19 * v34) (tpair (v19 * v33) (tpair (v19 * v32) (tpair (v19 * v31) (tpair (v19 * v30) (tpair (v19 * v29) (tpair (v19 * v28) (tpair (v19 * v27) Z1)))))))))) v26)"
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
(TKR2 1 (TKScalar Double))
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPretty (AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
(TKR2 1 (TKScalar Double))
-> AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
(TKR2 1 (TKScalar Double))
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifact AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
(TKR2 1 (TKScalar Double))
artifactRevnonLin)
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\dret v1 -> let v12 = recip (sconcrete (sreplicate [3] 1.0) + exp (negate (sfromVector (fromList [sdot0 (sconcrete (sreplicate [784] 7.0)) (tproject1 (tproject1 (tproject1 (tproject1 v1)))), sdot0 (sconcrete (sreplicate [784] 7.0)) (tproject1 (tproject2 (tproject1 (tproject1 (tproject1 v1))))), sdot0 (sconcrete (sreplicate [784] 7.0)) (tproject1 (tproject2 (tproject2 (tproject1 (tproject1 (tproject1 v1))))))])) + negate (tproject2 (tproject1 (tproject1 v1))))) ; v15 = scast v12 ; v19 = recip (sconcrete (sreplicate [4] 1.0) + exp (negate (scast (sfromVector (fromList [sdot0 (tproject1 (tproject1 (tproject2 (tproject1 v1)))) v15, sdot0 (tproject1 (tproject2 (tproject1 (tproject2 (tproject1 v1))))) v15, sdot0 (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1)))))) v15, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1))))))) v15]))) + negate (tproject2 (tproject2 (tproject1 v1))))) ; v22 = exp (sfromVector (fromList [sdot0 (tproject1 (tproject1 (tproject2 v1))) v19, sdot0 (tproject1 (tproject2 (tproject1 (tproject2 v1)))) v19, sdot0 (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 v1))))) v19, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))) v19, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))) v19, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))) v19, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))) v19, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))))) v19, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))))) v19, sdot0 (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 
(tproject2 v1)))))))))))) v19]) + tproject2 (tproject2 v1)) ; x23 = ssum0 v22 ; v26 = v22 * (sreplicate @10 (negate (recip (x23 * x23)) * sdot0 v22 (sfromR dret)) + sreplicate @10 (recip x23) * sfromR dret) ; x27 = v26 !$ [9] ; x28 = v26 !$ [8] ; x29 = v26 !$ [7] ; x30 = v26 !$ [6] ; x31 = v26 !$ [5] ; x32 = v26 !$ [4] ; x33 = v26 !$ [3] ; x34 = v26 !$ [2] ; x35 = v26 !$ [1] ; x36 = v26 !$ [0] ; v37 = (v19 * (sconcrete (sreplicate [4] 1.0) + negate v19)) * (tproject1 (tproject1 (tproject2 v1)) * sreplicate @4 x36 + (tproject1 (tproject2 (tproject1 (tproject2 v1))) * sreplicate @4 x35 + (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 v1)))) * sreplicate @4 x34 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))) * sreplicate @4 x33 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))) * sreplicate @4 x32 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))) * sreplicate @4 x31 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))) * sreplicate @4 x30 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))) * sreplicate @4 x29 + (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1)))))))))) * sreplicate @4 x28 + tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 v1))))))))))) * sreplicate @4 x27))))))))) ; v38 = scast v37 ; x39 = v38 !$ [3] ; x40 = v38 !$ [2] ; x41 = v38 !$ [1] ; x42 = v38 !$ [0] ; v43 = (v12 * (sconcrete (sreplicate [3] 1.0) + negate v12)) * scast (tproject1 (tproject1 (tproject2 (tproject1 v1))) * sreplicate @3 x42 + (tproject1 (tproject2 (tproject1 (tproject2 (tproject1 v1)))) * sreplicate @3 x41 + (tproject1 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 
v1))))) * sreplicate @3 x40 + tproject1 (tproject2 (tproject2 (tproject2 (tproject1 (tproject2 (tproject1 v1)))))) * sreplicate @3 x39))) in tpair (tpair (tpair (tpair (sconcrete (sreplicate [784] 7.0) * sreplicate @784 (v43 !$ [0])) (tpair (sconcrete (sreplicate [784] 7.0) * sreplicate @784 (v43 !$ [1])) (tpair (sconcrete (sreplicate [784] 7.0) * sreplicate @784 (v43 !$ [2])) Z1))) v43) (tpair (tpair (v15 * sreplicate @3 x42) (tpair (v15 * sreplicate @3 x41) (tpair (v15 * sreplicate @3 x40) (tpair (v15 * sreplicate @3 x39) Z1)))) v37)) (tpair (tpair (v19 * sreplicate @4 x36) (tpair (v19 * sreplicate @4 x35) (tpair (v19 * sreplicate @4 x34) (tpair (v19 * sreplicate @4 x33) (tpair (v19 * sreplicate @4 x32) (tpair (v19 * sreplicate @4 x31) (tpair (v19 * sreplicate @4 x30) (tpair (v19 * sreplicate @4 x29) (tpair (v19 * sreplicate @4 x28) (tpair (v19 * sreplicate @4 x27) Z1)))))))))) v26)"
testVTOAstNonLin :: Assertion
testVTOAstNonLin :: Assertion
testVTOAstNonLin = do
let ftk :: FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
ftk = forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams 3 4 Double))
(forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete ((ListR
3
(Concrete
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))),
Concrete
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double))),
(ListR 4 (Concrete (TKS ((':) @Natural 3 ('[] @Natural)) Float)),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(ListR
SizeMnistLabel
(Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
ADFcnnMnist1Parameters Concrete 3 4 Double
forall r.
(Num r, Enum r, PrimElt r) =>
ADFcnnMnist1Parameters Concrete 3 4 r
valsInitVTOPP)
varName :: AstVarName
s
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
varName = FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
-> Maybe (Int64, Int64)
-> AstVarId
-> AstVarName
s
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
forall (s :: AstSpanType) (y :: TK).
FullShapeTK y -> Maybe (Int64, Int64) -> AstVarId -> AstVarName s y
mkAstVarName FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
ftk Maybe (Int64, Int64)
forall a. Maybe a
Nothing (AstVarId
-> AstVarName
s
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
-> (Int -> AstVarId)
-> Int
-> AstVarName
s
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> AstVarId
intToAstVarId (Int
-> AstVarName
s
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
-> Int
-> AstVarName
s
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
forall a b. (a -> b) -> a -> b
$ Int
100000000
var :: AstTensor AstMethodLet FullSpan (XParams 3 4 Double)
var :: AstTensor AstMethodLet FullSpan (XParams 3 4 Double)
var = AstVarName
FullSpan
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
-> AstTensor
AstMethodLet
FullSpan
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
forall (b :: AstSpanType) (c :: TK) (a :: AstMethodOfSharing).
AstVarName b c -> AstTensor a b c
AstVar AstVarName
FullSpan
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
forall {s :: AstSpanType}.
AstVarName
s
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
varName
vals :: Concrete
(X ((ListR
3
(Concrete
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))),
Concrete
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double))),
(ListR 4 (Concrete (TKS ((':) @Natural 3 ('[] @Natural)) Float)),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(ListR
SizeMnistLabel
(Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
vals = forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete ((ListR
3
(Concrete
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))),
Concrete
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double))),
(ListR 4 (Concrete (TKS ((':) @Natural 3 ('[] @Natural)) Float)),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(ListR
SizeMnistLabel
(Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
ADFcnnMnist1Parameters Concrete 3 4 Double
forall r.
(Num r, Enum r, PrimElt r) =>
ADFcnnMnist1Parameters Concrete 3 4 r
valsInitVTOPP
env :: AstEnv Concrete
env = AstVarName
(ZonkAny @AstSpanType 6)
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
-> AstEnv Concrete
-> AstEnv Concrete
forall (target :: Target) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
(ZonkAny @AstSpanType 6)
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
forall {s :: AstSpanType}.
AstVarName
s
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
varName Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
(TKProduct
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))
TKUnit)))
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct
(TKS ((':) @Natural 3 ('[] @Natural)) Float)
(TKProduct (TKS ((':) @Natural 3 ('[] @Natural)) Float) TKUnit))))
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double)
(TKProduct
(TKS ((':) @Natural 4 ('[] @Natural)) Double) TKUnit))))))))))
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
vals AstEnv Concrete
forall (target :: Target). AstEnv target
emptyEnv
blackGlyph :: Concrete (BuildTensorKind 784 (TKR 0 Double))
blackGlyph = SNat 784
-> SingletonTK (TKR 0 Double)
-> Concrete (TKR 0 Double)
-> Concrete (BuildTensorKind 784 (TKR 0 Double))
forall (z :: TK) (k :: Natural).
ConvertTensor Concrete =>
SNat k
-> SingletonTK z -> Concrete z -> Concrete (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Natural).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate (forall (n :: Natural). KnownNat n => SNat n
SNat @SizeMnistGlyph) SingletonTK (TKR 0 Double)
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK (Concrete (TKR 0 Double)
-> Concrete (BuildTensorKind 784 (TKR 0 Double)))
-> Concrete (TKR 0 Double)
-> Concrete (BuildTensorKind 784 (TKR 0 Double))
forall a b. (a -> b) -> a -> b
$ Double -> Concrete (TKR 0 Double)
forall r (target :: Target).
(GoodScalar r, BaseTensor target) =>
r -> target (TKR 0 r)
rscalar Double
7
afcnn2 :: ADReady f
=> MnistFcnnRanked1.ADFcnnMnist1Parameters f 3 4 Double
-> f (TKR 1 Double)
afcnn2 :: forall (f :: Target).
ADReady f =>
ADFcnnMnist1Parameters f 3 4 Double -> f (TKR2 1 (TKScalar Double))
afcnn2 = (forall (n :: Natural).
KnownNat n =>
f (TKS ((':) @Natural n ('[] @Natural)) Double)
-> f (TKS ((':) @Natural n ('[] @Natural)) Double))
-> (f (TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)
-> f (TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))
-> SNat 3
-> SNat 4
-> f (TKS ((':) @Natural SizeMnistGlyph ('[] @Natural)) Double)
-> ADFcnnMnist1Parameters f 3 4 Double
-> f (TKR2 1 (TKScalar Double))
forall (target :: Target) r (widthHidden :: Natural)
(widthHidden2 :: Natural).
(ADReady target, GoodScalar r, Differentiable r) =>
(forall (n :: Natural).
KnownNat n =>
target (TKS ((':) @Natural n ('[] @Natural)) r)
-> target (TKS ((':) @Natural n ('[] @Natural)) r))
-> (target (TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) r)
-> target (TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) r))
-> SNat widthHidden
-> SNat widthHidden2
-> target (TKS ((':) @Natural SizeMnistGlyph ('[] @Natural)) r)
-> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
-> target (TKR 1 r)
MnistFcnnRanked1.afcnnMnist1
f (TKS ((':) @Natural n ('[] @Natural)) Double)
-> f (TKS ((':) @Natural n ('[] @Natural)) Double)
forall (n :: Natural).
KnownNat n =>
f (TKS ((':) @Natural n ('[] @Natural)) Double)
-> f (TKS ((':) @Natural n ('[] @Natural)) Double)
forall (target :: Target) r (sh :: [Natural]).
(BaseTensor target, LetTensor target, BaseTensor (PrimalOf target),
KnownShS sh, GoodScalar r, Differentiable r) =>
target (TKS sh r) -> target (TKS sh r)
logisticS f (TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)
-> f (TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)
forall (target :: Target) (sh :: [Natural]) r.
(KnownShS sh, BaseTensor target, LetTensor target, GoodScalar r,
Differentiable r) =>
target (TKS sh r) -> target (TKS sh r)
softMax1S (forall (n :: Natural). KnownNat n => SNat n
SNat @3) (forall (n :: Natural). KnownNat n => SNat n
SNat @4)
(f (TKR2
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural)))
(TKScalar Double))
-> f (TKS ((':) @Natural SizeMnistGlyph ('[] @Natural)) Double)
forall (sh :: [Natural]) (x :: TK).
(KnownShS sh, KnownSTK x) =>
f (TKR2 (Rank @Natural sh) x) -> f (TKS2 sh x)
forall (target :: Target) (sh :: [Natural]) (x :: TK).
(ConvertTensor target, KnownShS sh, KnownSTK x) =>
target (TKR2 (Rank @Natural sh) x) -> target (TKS2 sh x)
sfromR (f (TKR2
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural)))
(TKScalar Double))
-> f (TKS ((':) @Natural SizeMnistGlyph ('[] @Natural)) Double))
-> f (TKR2
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural)))
(TKScalar Double))
-> f (TKS ((':) @Natural SizeMnistGlyph ('[] @Natural)) Double)
forall a b. (a -> b) -> a -> b
$ Ranked
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural)))
Double
-> f (TKR2
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural)))
(TKScalar Double))
forall r (target :: Target) (n :: Natural).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete (Ranked
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural)))
Double
-> f (TKR2
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural)))
(TKScalar Double)))
-> Ranked
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural)))
Double
-> f (TKR2
(Rank @Natural ((':) @Natural SizeMnistGlyph ('[] @Natural)))
(TKScalar Double))
forall a b. (a -> b) -> a -> b
$ Concrete (TKR2 1 (TKScalar Double))
-> RepConcrete (TKR2 1 (TKScalar Double))
forall (y :: TK). Concrete y -> RepConcrete y
unConcrete Concrete (TKR2 1 (TKScalar Double))
blackGlyph)
afcnn1 :: AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
afcnn1 = ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) 3 4 Double
-> AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
forall (f :: Target).
ADReady f =>
ADFcnnMnist1Parameters f 3 4 Double -> f (TKR2 1 (TKScalar Double))
afcnn2 (ADFcnnMnist1Parameters
(AstTensor AstMethodLet FullSpan) 3 4 Double
-> AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double)))
-> ADFcnnMnist1Parameters
(AstTensor AstMethodLet FullSpan) 3 4 Double
-> AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
forall a b. (a -> b) -> a -> b
$ AstTensor
AstMethodLet
FullSpan
(X ((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double))),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
-> ((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double))),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget AstTensor
AstMethodLet
FullSpan
(X ((ListR
3
(AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 784 ('[] @Natural)) (TKScalar Double))),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural 3 ('[] @Natural)) (TKScalar Double))),
(ListR
4
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 3 ('[] @Natural)) Float)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(ListR
SizeMnistLabel
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
AstTensor AstMethodLet FullSpan (XParams 3 4 Double)
var
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull @Concrete AstEnv Concrete
env AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
afcnn1
Concrete (TKR2 1 (TKScalar Double))
-> Concrete (TKR2 1 (TKScalar Double)) -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= ADFcnnMnist1Parameters Concrete 3 4 Double
-> Concrete (TKR2 1 (TKScalar Double))
forall (f :: Target).
ADReady f =>
ADFcnnMnist1Parameters f 3 4 Double -> f (TKR2 1 (TKScalar Double))
afcnn2 ADFcnnMnist1Parameters Concrete 3 4 Double
forall r.
(Num r, Enum r, PrimElt r) =>
ADFcnnMnist1Parameters Concrete 3 4 r
valsInitVTOPP
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull @Concrete AstEnv Concrete
env
(forall (z :: TK) (s :: AstSpanType).
AstSpan s =>
AstTensor AstMethodLet s z -> AstTensor AstMethodLet s z
simplifyInline @(TKR 1 Double) AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
afcnn1)
Concrete (TKR2 1 (TKScalar Double))
-> Concrete (TKR2 1 (TKScalar Double)) -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= ADFcnnMnist1Parameters Concrete 3 4 Double
-> Concrete (TKR2 1 (TKScalar Double))
forall (f :: Target).
ADReady f =>
ADFcnnMnist1Parameters f 3 4 Double -> f (TKR2 1 (TKScalar Double))
afcnn2 ADFcnnMnist1Parameters Concrete 3 4 Double
forall r.
(Num r, Enum r, PrimElt r) =>
ADFcnnMnist1Parameters Concrete 3 4 r
valsInitVTOPP
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull @Concrete AstEnv Concrete
env
(forall (z :: TK) (s :: AstSpanType).
AstSpan s =>
AstTensor AstMethodLet s z -> AstTensor AstMethodLet s z
simplifyInlineContract @(TKR 1 Double) AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
afcnn1)
Concrete (TKR2 1 (TKScalar Double))
-> Concrete (TKR2 1 (TKScalar Double)) -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= ADFcnnMnist1Parameters Concrete 3 4 Double
-> Concrete (TKR2 1 (TKScalar Double))
forall (f :: Target).
ADReady f =>
ADFcnnMnist1Parameters f 3 4 Double -> f (TKR2 1 (TKScalar Double))
afcnn2 ADFcnnMnist1Parameters Concrete 3 4 Double
forall r.
(Num r, Enum r, PrimElt r) =>
ADFcnnMnist1Parameters Concrete 3 4 r
valsInitVTOPP
-- | Fixed, hand-written parameters for the two-hidden-layer ranked FCNN
-- ('MnistFcnnRanked2') used by the @VTO2@ tests below.
--
-- The value is a triple of pairs; each pair holds a rank-2 tensor and a
-- rank-1 tensor (presumably the weight matrix and the bias vector of one
-- layer -- confirm against "MnistFcnnRanked2"):
--
-- * shapes @[4, 3]@ and @[4]@, both 'Double';
-- * shapes @[5, 4]@ ('Float') and @[5]@ ('Double');
-- * shapes @[2, 5]@ and @[2]@, both 'Double'.
--
-- Every rank-2 tensor is filled linearly with the sequence @1..k@ repeated
-- once per outer index; every rank-1 tensor is @1..n@.  The shape literals
-- rely on @OverloadedLists@.
valsInitVT2OPP :: MnistFcnnRanked2.ADFcnnMnist2Parameters Concrete Double Float
valsInitVT2OPP =
  ( ( Concrete $ Nested.rfromListPrimLinear [4, 3]
                   (concat $ replicate 4 [1, 2, 3])
    , Concrete $ Nested.rfromListPrimLinear [4] [1, 2, 3, 4] )
  , ( Concrete $ Nested.rfromListPrimLinear [5, 4]
                   (concat $ replicate 5 [1, 2, 3, 4])
    , Concrete $ Nested.rfromListPrimLinear [5] [1, 2, 3, 4, 5] )
  , ( Concrete $ Nested.rfromListPrimLinear [2, 5]
                   (concat $ replicate 2 [1, 2, 3, 4, 5])
    , Concrete $ Nested.rfromListPrimLinear [2] [1, 2] ) )
-- | Golden pretty-printing test for the reverse-derivative artifact of
-- 'MnistFcnnRanked2.afcnnMnist2' with both activation arguments set to 'id'.
--
-- The network is applied to a constant 3-element input (every entry @7@)
-- and differentiated with respect to its parameters, whose shapes are taken
-- from 'valsInitVT2OPP'.  The test then compares, against fixed strings:
--
-- 1. the primal part of the artifact,
-- 2. the full artifact before simplification,
-- 3. the full artifact after 'simplifyArtifact'.
testVT2OPP :: Assertion
testVT2OPP = do
  -- Presumably makes the generated variable names (m5, v8, ...) in the
  -- golden strings reproducible regardless of which tests ran earlier.
  resetVarCounter
  let blackGlyph = treplicate (SNat @3) knownSTK
                   $ fromPrimal $ rconcrete $ Nested.rscalar 7
      afcnn2T :: MnistFcnnRanked2.ADFcnnMnist2Parameters
                   (AstTensor AstMethodLet FullSpan) Double Float
              -> AstTensor AstMethodLet FullSpan (TKR 1 Double)
      afcnn2T = MnistFcnnRanked2.afcnnMnist2 id id blackGlyph
      -- Only the shape information of the concrete parameters is used here.
      ftk = tftk @Concrete (knownSTK @(XParams2 Double Float))
                 (toTarget @Concrete valsInitVT2OPP)
      artifactRev = revArtifactAdapt UseIncomingCotangent afcnn2T ftk
  printArtifactPrimalPretty artifactRev
    @?= "\\m1 -> let m5 = str (sreplicate @5 (scast (ssum @3 (sconcrete (sreplicate [3,4] 7.0) * str (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1)))))) ; m6 = str (sreplicate @2 (scast (ssum @4 (m5 * str (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1))))) in rfromS (ssum @5 (m6 * str (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1)))"
  printArtifactPretty artifactRev
    @?= "\\dret m1 -> let m5 = str (sreplicate @5 (scast (ssum @3 (sconcrete (sreplicate [3,4] 7.0) * str (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1)))))) ; m6 = str (sreplicate @2 (scast (ssum @4 (m5 * str (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1))))) ; v8 = ssum @2 (str (str (sfromR (tproject1 (tproject2 m1))) * sreplicate @5 (sfromR dret))) ; m9 = sreplicate @4 (scast v8) ; v10 = scast (ssum @5 (str (str (sfromR (tproject1 (tproject2 (tproject1 m1)))) * m9))) in tpair (tpair (tpair (rfromS (str (sconcrete (sreplicate [3,4] 7.0) * sreplicate @3 v10))) (rfromS v10)) (tpair (rfromS (str (m5 * m9))) (rfromS v8))) (tpair (rfromS (str (m6 * sreplicate @5 (sfromR dret)))) dret)"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret m1 -> tconvert (ConvT2 (ConvT2 (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [4,3] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [4] FTKScalar)) ConvSX))) (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [5,4] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [5] FTKScalar)) ConvSX)))) (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,5] FTKScalar)) ConvSX)) ConvId)) (STKProduct (STKProduct (STKProduct (STKS [4,3] STKScalar) (STKS [4] STKScalar)) (STKProduct (STKS [5,4] STKScalar) (STKS [5] STKScalar))) (STKProduct (STKS [2,5] STKScalar) (STKR (SNat @1) STKScalar))) (let v5 = scast (sdot1In (sconcrete (sreplicate [4,3] 7.0)) (sfromR (tproject1 (tproject1 (tproject1 m1)))) + sfromR (tproject2 (tproject1 (tproject1 m1)))) ; v8 = sdot1In (str (sfromR (tproject1 (tproject2 m1)))) (sreplicate @5 (sfromR dret)) ; v9 = scast v8 ; v10 = scast (sdot1In (str (sfromR (tproject1 (tproject2 (tproject1 m1))))) (sreplicate @4 v9)) in tpair (tpair (tpair (sconcrete (sreplicate [4,3] 7.0) * str (sreplicate @3 v10)) v10) (tpair (sreplicate @5 v5 * str (sreplicate @4 v9)) v8)) (tpair (sreplicate @2 (scast (sdot1In (sreplicate @5 v5) (sfromR (tproject1 (tproject2 (tproject1 m1))))) + sfromR (tproject2 (tproject2 (tproject1 m1)))) * str (sreplicate @5 (sfromR dret))) dret))"
-- | Checks that interpreting the AST of 'MnistFcnnRanked2.afcnnMnist2'
-- (with both activation arguments set to 'id') in the 'Concrete' target
-- gives the same result as running the network directly on concrete
-- parameters.
--
-- The AST is built by applying the network to parameters obtained from a
-- single AST variable; that variable is bound to 'valsInitVT2OPP' in the
-- interpretation environment.  The comparison is made three times: on the
-- AST as built, after 'simplifyInline', and after 'simplifyInlineContract'.
testVT2OAst :: Assertion
testVT2OAst = do
  let ftk = tftk @Concrete (knownSTK @(XParams2 Double Float))
                 (toTarget @Concrete valsInitVT2OPP)
      -- NOTE(review): the id 100000000 is presumably chosen to stay clear
      -- of ids handed out by the fresh-variable counter -- confirm.
      varName = mkAstVarName ftk Nothing . intToAstVarId $ 100000000
      var :: AstTensor AstMethodLet FullSpan (XParams2 Double Float)
      var = AstVar varName
      vals = toTarget @Concrete valsInitVT2OPP
      env = extendEnv varName vals emptyEnv
      -- Constant 3-element input with every entry equal to 7.
      blackGlyph = treplicate (SNat @3) knownSTK $ rscalar 7
      afcnn2 :: ADReady f
             => MnistFcnnRanked2.ADFcnnMnist2Parameters f Double Float
             -> f (TKR 1 Double)
      afcnn2 = MnistFcnnRanked2.afcnnMnist2
                 id id (rconcrete $ unConcrete blackGlyph)
      afcnn1 = afcnn2 $ fromTarget var
  interpretAstFull @Concrete env afcnn1
    @?= afcnn2 valsInitVT2OPP
  interpretAstFull @Concrete env
                   (simplifyInline @(TKR 1 Double) afcnn1)
    @?= afcnn2 valsInitVT2OPP
  interpretAstFull @Concrete env
                   (simplifyInlineContract @(TKR 1 Double) afcnn1)
    @?= afcnn2 valsInitVT2OPP
-- | Golden pretty-printing test for 'MnistFcnnRanked2.afcnnMnist2' with
-- non-linear activations ('logistic' and 'softMax1'), built purely from
-- constant AST terms.
--
-- The network is applied to a constant glyph and to the values of
-- @valsInitVT2OPP@ embedded as AST constants. The resulting term is wrapped
-- by 'fun1ToAst' into a function of a variable it never uses, and its
-- 'printAstSimple' rendering is compared with an expected string twice:
-- once as built and once after 'simplifyInlineContract'.
testVT2OPPNonLin :: Assertion
testVT2OPPNonLin = do
  -- Presumably makes the generated variable names that appear in the golden
  -- strings below (v5 .. v9) reproducible -- TODO confirm against
  -- HordeAd.Core.AstFreshId.
  resetVarCounter
  let -- A rank-0 constant 7 replicated 3 times; the element type (Float) is
      -- fixed by the signature of afcnn2TnonLin below.
      blackGlyph = treplicate (SNat @3) knownSTK
                   $ fromPrimal $ rconcrete $ Nested.rscalar 7
      afcnn2TnonLin :: MnistFcnnRanked2.ADFcnnMnist2Parameters
                         (AstTensor AstMethodLet FullSpan) Float Float
                    -> AstTensor AstMethodLet FullSpan (TKR 1 Float)
      afcnn2TnonLin =
        MnistFcnnRanked2.afcnnMnist2 logistic softMax1 blackGlyph
      -- The initial parameter values turned into AST constants and cast to
      -- the Float scalars that afcnn2TnonLin expects.
      -- NOTE: the order of rcast and fromPrimal differs between components
      -- on purpose; it is reflected in the first golden string (scast either
      -- inside or outside tfromPrimal), so do not normalise it.
      constant =
        let ((a1, a2), (a3, a4), (a5, a6)) = valsInitVT2OPP
        in ( ( rcast $ fromPrimal $ rconcrete $ unConcrete a1
             , rcast $ fromPrimal $ rconcrete $ unConcrete a2 )
           , ( fromPrimal $ rcast $ rconcrete $ unConcrete a3
             , fromPrimal $ rcast $ rconcrete $ unConcrete a4 )
           , ( rcast $ fromPrimal $ rconcrete $ unConcrete a5
             , fromPrimal $ rcast $ rconcrete $ unConcrete a6 ) )
      -- The bound variable is ignored (const), hence "\\dummy" is prepended
      -- by hand when printing.
      ast3 = fun1ToAst (FTKR (0 :$: ZSR) (FTKScalar @Float))
                       (const $ afcnn2TnonLin constant)
  -- The term exactly as constructed.
  "\\dummy" ++ " -> " ++ printAstSimple ast3
    @?= "\\dummy -> rfromS (tlet (exp (ssum @5 (str (sreplicate @2 (tlet (ssum @4 (str (sreplicate @5 (tlet (tfromPrimal (STKS [4] STKScalar) (ssum @3 (sconcrete (sreplicate [3,4] 7.0) * str (scast (sconcrete (sfromListLinear [4,3] [1.0,2.0,3.0,1.0,2.0,3.0,1.0,2.0,3.0,1.0,2.0,3.0])))) + scast (sconcrete (sfromListLinear [4] [1.0,2.0,3.0,4.0])))) (\\v5 -> ttletPrimal (recip (sconcrete (sreplicate [4] 1.0) + exp (negate (sfromR (tprimalPart (rfromS v5)))))) (\\v6 -> tfromPrimal (STKS [4] STKScalar) v6 + sfromR (tfromDual (tdualPart (STKR (SNat @1) STKScalar) (rfromS (tfromPrimal (STKS [4] STKScalar) (v6 * (sconcrete (sreplicate [4] 1.0) + negate v6)) * sfromR (tfromDual (tdualPart (STKR (SNat @1) STKScalar) (rfromS v5))))))))))) * tfromPrimal (STKS [4,5] STKScalar) (sconcrete (sfromListLinear [4,5] [1.0,1.0,1.0,1.0,1.0,2.0,2.0,2.0,2.0,2.0,3.0,3.0,3.0,3.0,3.0,4.0,4.0,4.0,4.0,4.0]))) + tfromPrimal (STKS [5] STKScalar) (scast (sconcrete (sfromListLinear [5] [1.0,2.0,3.0,4.0,5.0])))) (\\v7 -> ttletPrimal (recip (sconcrete (sreplicate [5] 1.0) + exp (negate (sfromR (tprimalPart (rfromS v7)))))) (\\v8 -> tfromPrimal (STKS [5] STKScalar) v8 + sfromR (tfromDual (tdualPart (STKR (SNat @1) STKScalar) (rfromS (tfromPrimal (STKS [5] STKScalar) (v8 * (sconcrete (sreplicate [5] 1.0) + negate v8)) * sfromR (tfromDual (tdualPart (STKR (SNat @1) STKScalar) (rfromS v7))))))))))) * tfromPrimal (STKS [5,2] STKScalar) (str (scast (sconcrete (sfromListLinear [2,5] [1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0]))))) + tfromPrimal (STKS [2] STKScalar) (scast (sconcrete (sfromListLinear [2] [1.0,2.0]))))) (\\v9 -> sreplicate @2 (recip (ssum @2 v9)) * v9))"
  -- The same term after simplification, inlining and contraction.
  "\\dummy" ++ " -> " ++ printAstSimple (simplifyInlineContract ast3)
    @?= "\\dummy -> rfromS (tlet (exp (sdot1In (sreplicate @2 (tlet (sdot1In (sreplicate @5 (ttletPrimal (recip (sconcrete (sreplicate [4] 1.0) + exp (sconcrete (sfromListLinear [4] [-43.0,-44.0,-45.0,-46.0])))) (\\v6 -> tfromPrimal (STKS [4] STKScalar) v6 + tfromDual (tdualPart (STKS [4] STKScalar) (tfromPrimal (STKS [4] STKScalar) (v6 * (sconcrete (sreplicate [4] 1.0) + negate v6)) * tfromDual (tdualPart (STKS [4] STKScalar) (tfromPrimal (STKS [4] STKScalar) (sconcrete (sreplicate [4] 0.0))))))))) (tfromPrimal (STKS [5,4] STKScalar) (sconcrete (sfromListLinear [5,4] [1.0,2.0,3.0,4.0,1.0,2.0,3.0,4.0,1.0,2.0,3.0,4.0,1.0,2.0,3.0,4.0,1.0,2.0,3.0,4.0]))) + tfromPrimal (STKS [5] STKScalar) (sconcrete (sfromListLinear [5] [1.0,2.0,3.0,4.0,5.0]))) (\\v7 -> ttletPrimal (recip (sconcrete (sreplicate [5] 1.0) + exp (negate (tprimalPart v7)))) (\\v8 -> tfromPrimal (STKS [5] STKScalar) v8 + tfromDual (tdualPart (STKS [5] STKScalar) (tfromPrimal (STKS [5] STKScalar) (v8 * (sconcrete (sreplicate [5] 1.0) + negate v8)) * tfromDual (tdualPart (STKS [5] STKScalar) v7))))))) (tfromPrimal (STKS [2,5] STKScalar) (sconcrete (sfromListLinear [2,5] [1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0]))) + tfromPrimal (STKS [2] STKScalar) (sconcrete (sfromListLinear [2] [1.0,2.0])))) (\\v9 -> sreplicate @2 (recip (ssum0 v9)) * v9))"
testVT2OPPNonLin2 :: Assertion
testVT2OPPNonLin2 :: Assertion
testVT2OPPNonLin2 = do
Assertion
resetVarCounter
let blackGlyph :: AstTensor AstMethodLet FullSpan (BuildTensorKind 3 (TKR 0 Double))
blackGlyph = SNat 3
-> SingletonTK (TKR 0 Double)
-> AstTensor AstMethodLet FullSpan (TKR 0 Double)
-> AstTensor
AstMethodLet FullSpan (BuildTensorKind 3 (TKR 0 Double))
forall (z :: TK) (k :: Natural).
ConvertTensor (AstTensor AstMethodLet FullSpan) =>
SNat k
-> SingletonTK z
-> AstTensor AstMethodLet FullSpan z
-> AstTensor AstMethodLet FullSpan (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Natural).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate (forall (n :: Natural). KnownNat n => SNat n
SNat @3) SingletonTK (TKR 0 Double)
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK
(AstTensor AstMethodLet FullSpan (TKR 0 Double)
-> AstTensor
AstMethodLet FullSpan (BuildTensorKind 3 (TKR 0 Double)))
-> AstTensor AstMethodLet FullSpan (TKR 0 Double)
-> AstTensor
AstMethodLet FullSpan (BuildTensorKind 3 (TKR 0 Double))
forall a b. (a -> b) -> a -> b
$ AstTensor AstMethodLet PrimalSpan (TKR 0 Double)
-> AstTensor AstMethodLet FullSpan (TKR 0 Double)
forall (ms :: AstMethodOfSharing) (y :: TK).
AstTensor ms PrimalSpan y -> AstTensor ms FullSpan y
forall (s :: AstSpanType) (ms :: AstMethodOfSharing) (y :: TK).
AstSpan s =>
AstTensor ms PrimalSpan y -> AstTensor ms s y
fromPrimal (AstTensor AstMethodLet PrimalSpan (TKR 0 Double)
-> AstTensor AstMethodLet FullSpan (TKR 0 Double))
-> AstTensor AstMethodLet PrimalSpan (TKR 0 Double)
-> AstTensor AstMethodLet FullSpan (TKR 0 Double)
forall a b. (a -> b) -> a -> b
$ Ranked 0 Double -> AstTensor AstMethodLet PrimalSpan (TKR 0 Double)
forall r (target :: Target) (n :: Natural).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete (Ranked 0 Double
-> AstTensor AstMethodLet PrimalSpan (TKR 0 Double))
-> Ranked 0 Double
-> AstTensor AstMethodLet PrimalSpan (TKR 0 Double)
forall a b. (a -> b) -> a -> b
$ Double -> Ranked 0 Double
forall a. Elt a => a -> Ranked 0 a
Nested.rscalar Double
7
afcnn2TnonLin :: MnistFcnnRanked2.ADFcnnMnist2Parameters
(AstTensor AstMethodLet FullSpan) Double Float
-> AstTensor AstMethodLet FullSpan (TKR 1 Double)
afcnn2TnonLin :: ADFcnnMnist2Parameters
(AstTensor AstMethodLet FullSpan) Double Float
-> AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
afcnn2TnonLin = (AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
-> AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double)))
-> (AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
-> AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double)))
-> AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
-> ADFcnnMnist2Parameters
(AstTensor AstMethodLet FullSpan) Double Float
-> AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
forall (target :: Target) r q.
(ADReady target, GoodScalar r, Differentiable r, GoodScalar q,
Differentiable q) =>
(target (TKR 1 r) -> target (TKR 1 r))
-> (target (TKR 1 r) -> target (TKR 1 r))
-> target (TKR 1 r)
-> ADFcnnMnist2Parameters target r q
-> target (TKR 1 r)
MnistFcnnRanked2.afcnnMnist2 AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
-> AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
forall (target :: Target) r (n :: Natural).
(BaseTensor target, LetTensor target, BaseTensor (PrimalOf target),
KnownNat n, GoodScalar r, Differentiable r) =>
target (TKR n r) -> target (TKR n r)
logistic AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
-> AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
forall (target :: Target) (n :: Natural) r.
(BaseTensor target, LetTensor target, KnownNat n, GoodScalar r,
Differentiable r) =>
target (TKR n r) -> target (TKR n r)
softMax1 AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
blackGlyph
ftk :: FullShapeTK
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Float) (TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
ftk = forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams2 Double Float))
(forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete ADFcnnMnist2Parameters Concrete Double Float
valsInitVT2OPP)
artifactRevnonLin :: AstArtifactRev
(X (ADFcnnMnist2Parameters
(AstTensor AstMethodLet FullSpan) Double Float))
(TKR2 1 (TKScalar Double))
artifactRevnonLin =
IncomingCotangentHandling
-> (ADFcnnMnist2Parameters
(AstTensor AstMethodLet FullSpan) Double Float
-> AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double)))
-> FullShapeTK
(X (ADFcnnMnist2Parameters
(AstTensor AstMethodLet FullSpan) Double Float))
-> AstArtifactRev
(X (ADFcnnMnist2Parameters
(AstTensor AstMethodLet FullSpan) Double Float))
(TKR2 1 (TKScalar Double))
forall src (ztgt :: TK) tgt.
(AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
(tgt :: Type) ~ (AstTensor AstMethodLet FullSpan ztgt :: Type)) =>
IncomingCotangentHandling
-> (src -> tgt)
-> FullShapeTK (X src)
-> AstArtifactRev (X src) ztgt
revArtifactAdapt IncomingCotangentHandling
UseIncomingCotangent ADFcnnMnist2Parameters
(AstTensor AstMethodLet FullSpan) Double Float
-> AstTensor AstMethodLet FullSpan (TKR2 1 (TKScalar Double))
afcnn2TnonLin FullShapeTK
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Float) (TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
FullShapeTK
(X (ADFcnnMnist2Parameters
(AstTensor AstMethodLet FullSpan) Double Float))
ftk
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Float) (TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR2 1 (TKScalar Double))
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPrimalPretty (AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Float) (TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR2 1 (TKScalar Double))
-> AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Float) (TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR2 1 (TKScalar Double))
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifact AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Float) (TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR2 1 (TKScalar Double))
artifactRevnonLin)
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\m1 -> rfromS (let v23 = exp (sdot1In (sreplicate @2 (recip (sconcrete (sreplicate [5] 1.0) + exp (negate (scast (sdot1In (sreplicate @5 (scast (recip (sconcrete (sreplicate [4] 1.0) + exp (negate (sdot1In (sconcrete (sreplicate [4,3] 7.0)) (sfromR (tproject1 (tproject1 (tproject1 m1))))) + negate (sfromR (tproject2 (tproject1 (tproject1 m1))))))))) (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + negate (sfromR (tproject2 (tproject2 (tproject1 m1)))))))) (sfromR (tproject1 (tproject2 m1))) + sfromR (tproject2 (tproject2 m1))) in sreplicate @2 (recip (ssum0 v23)) * v23)"
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Float) (TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR2 1 (TKScalar Double))
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPrimalPretty AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Float) (TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR2 1 (TKScalar Double))
artifactRevnonLin
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\m1 -> let v10 = ssum @3 (sconcrete (sreplicate [3,4] 7.0) * str (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1))) ; v11 = exp (negate v10) ; v12 = sconcrete (sreplicate [4] 1.0) + v11 ; v13 = recip v12 ; m16 = str (sreplicate @5 (scast v13)) ; v17 = scast (ssum @4 (m16 * str (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1))) ; v18 = exp (negate v17) ; v19 = sconcrete (sreplicate [5] 1.0) + v18 ; v20 = recip v19 ; v23 = exp (ssum @5 (str (sreplicate @2 v20) * str (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1))) ; x24 = ssum @2 v23 ; v25 = sreplicate @2 (recip x24) in rfromS (v25 * v23)"
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Float) (TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR2 1 (TKScalar Double))
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPretty AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Float) (TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR2 1 (TKScalar Double))
artifactRevnonLin
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\dret m1 -> let v10 = ssum @3 (sconcrete (sreplicate [3,4] 7.0) * str (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1))) ; v11 = exp (negate v10) ; v12 = sconcrete (sreplicate [4] 1.0) + v11 ; v13 = recip v12 ; v14 = sconcrete (sreplicate [4] 1.0) + negate v13 ; v15 = v13 * v14 ; m16 = str (sreplicate @5 (scast v13)) ; v17 = scast (ssum @4 (m16 * str (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1))) ; v18 = exp (negate v17) ; v19 = sconcrete (sreplicate [5] 1.0) + v18 ; v20 = recip v19 ; v21 = sconcrete (sreplicate [5] 1.0) + negate v20 ; v22 = v20 * v21 ; v23 = exp (ssum @5 (str (sreplicate @2 v20) * str (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1))) ; x24 = ssum @2 v23 ; v25 = sreplicate @2 (recip x24) ; v27 = v23 * (sreplicate @2 (negate (recip (x24 * x24)) * ssum @2 (v23 * sfromR dret)) + v25 * sfromR dret) ; v28 = v22 * ssum @2 (str (str (sfromR (tproject1 (tproject2 m1))) * sreplicate @5 v27)) ; m29 = sreplicate @4 (scast v28) ; v30 = v15 * scast (ssum @5 (str (str (sfromR (tproject1 (tproject2 (tproject1 m1)))) * m29))) in tpair (tpair (tpair (rfromS (str (sconcrete (sreplicate [3,4] 7.0) * sreplicate @3 v30))) (rfromS v30)) (tpair (rfromS (str (m16 * m29))) (rfromS v28))) (tpair (rfromS (str (str (sreplicate @2 v20) * sreplicate @5 v27))) (rfromS v27))"
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Float) (TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR2 1 (TKScalar Double))
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPretty (AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Float) (TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR2 1 (TKScalar Double))
-> AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Float) (TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR2 1 (TKScalar Double))
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifact AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Float) (TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR2 1 (TKScalar Double))
artifactRevnonLin)
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\dret m1 -> tconvert (ConvT2 (ConvT2 (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [4,3] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [4] FTKScalar)) ConvSX))) (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [5,4] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [5] FTKScalar)) ConvSX)))) (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,5] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2] FTKScalar)) ConvSX)))) (STKProduct (STKProduct (STKProduct (STKS [4,3] STKScalar) (STKS [4] STKScalar)) (STKProduct (STKS [5,4] STKScalar) (STKS [5] STKScalar))) (STKProduct (STKS [2,5] STKScalar) (STKS [2] STKScalar))) (let v13 = recip (sconcrete (sreplicate [4] 1.0) + exp (negate (sdot1In (sconcrete (sreplicate [4,3] 7.0)) (sfromR (tproject1 (tproject1 (tproject1 m1))))) + negate (sfromR (tproject2 (tproject1 (tproject1 m1)))))) ; v16 = scast v13 ; v20 = recip (sconcrete (sreplicate [5] 1.0) + exp (negate (scast (sdot1In (sreplicate @5 v16) (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + negate (sfromR (tproject2 (tproject2 (tproject1 m1)))))) ; v23 = exp (sdot1In (sreplicate @2 v20) (sfromR (tproject1 (tproject2 m1))) + sfromR (tproject2 (tproject2 m1))) ; x24 = ssum0 v23 ; v27 = v23 * (sreplicate @2 (negate (recip (x24 * x24)) * sdot0 v23 (sfromR dret)) + sreplicate @2 (recip x24) * sfromR dret) ; v28 = (v20 * (sconcrete (sreplicate [5] 1.0) + negate v20)) * sdot1In (str (sfromR (tproject1 (tproject2 m1)))) (sreplicate @5 v27) ; v29 = scast v28 ; v30 = (v13 * (sconcrete (sreplicate [4] 1.0) + negate v13)) * scast (sdot1In (str (sfromR (tproject1 (tproject2 (tproject1 m1))))) (sreplicate @4 v29)) in tpair (tpair (tpair (sconcrete (sreplicate [4,3] 7.0) * str (sreplicate @3 v30)) v30) (tpair (sreplicate @5 v16 * str (sreplicate @4 v29)) v28)) (tpair (sreplicate @2 v20 * str (sreplicate @5 v27)) v27))"
-- | Interpreter round-trip test for the two-hidden-layer ranked MNIST FCNN
-- with non-linear activations ('logistic' for the hidden layers, 'softMax1'
-- for the output).
--
-- The network is built twice from the same polymorphic definition @afcnn2@:
--
-- * symbolically, by applying it to parameters obtained from a single AST
--   variable (@afcnn1@), and
-- * concretely, by applying it directly to 'valsInitVT2OPP'.
--
-- The test asserts that interpreting the AST in an environment that binds
-- the variable to the concrete parameters gives exactly the concrete result,
-- for the raw term, for the term after 'simplifyInline', and for the term
-- after 'simplifyInlineContract'.
testVT2OAstNonLin2 :: Assertion
testVT2OAstNonLin2 = do
  let -- Full shape of the parameter tuple, taken from the concrete
      -- initial parameters.
      ftk = tftk @Concrete (knownSTK @(XParams2 Double Float))
                 (toTarget @Concrete valsInitVT2OPP)
      -- The AST variable standing for the whole parameter tuple; its id is
      -- the fixed literal below rather than a freshly generated one.
      varName = mkAstVarName ftk Nothing . intToAstVarId $ 100000000
      var :: AstTensor AstMethodLet FullSpan (XParams2 Double Float)
      var = AstVar varName
      -- Concrete parameters and the environment binding @varName@ to them.
      vals = toTarget @Concrete valsInitVT2OPP
      env = extendEnv varName vals emptyEnv
      -- Constant input: the scalar 7 replicated 3 times.
      blackGlyph = treplicate (SNat @3) knownSTK $ rscalar 7
      -- The network, polymorphic in the target so that the same definition
      -- can be run both on AST terms and on concrete tensors.
      afcnn2 :: ADReady f
             => MnistFcnnRanked2.ADFcnnMnist2Parameters f Double Float
             -> f (TKR 1 Double)
      afcnn2 = MnistFcnnRanked2.afcnnMnist2 logistic softMax1
                                            (rconcrete $ unConcrete blackGlyph)
      -- The network as an AST term over @var@.
      afcnn1 = afcnn2 $ fromTarget var
  -- Unsimplified term.
  interpretAstFull @Concrete env afcnn1
    @?= afcnn2 valsInitVT2OPP
  -- After simplification with inlining.
  interpretAstFull @Concrete env
                   (simplifyInline @(TKR 1 Double) afcnn1)
    @?= afcnn2 valsInitVT2OPP
  -- After simplification with inlining and contraction.
  interpretAstFull @Concrete env
                   (simplifyInlineContract @(TKR 1 Double) afcnn1)
    @?= afcnn2 valsInitVT2OPP
-- | Pretty-printing test for the reverse-derivative artifact of
-- 'MnistFcnnRanked2.afcnnMnistLoss2' (the two-hidden-layer ranked MNIST FCNN
-- loss), taken with respect to the network parameters.
--
-- The input glyph is the scalar 7 replicated 3 times and the label is the
-- scalar 8 replicated 2 times, both embedded as primal constants in the AST.
-- The artifact is built with 'revArtifactAdapt' using
-- 'UseIncomingCotangent', and four renderings are compared against golden
-- strings: the primal part and the full artifact, each with and without
-- 'simplifyArtifact'.
--
-- The golden strings are compared with '@?=', so they must stay
-- byte-identical; each of them is a single-line literal.
testVT2OPPNonLin3 :: Assertion
testVT2OPPNonLin3 = do
  -- Presumably resets the fresh-variable counter so that the numbered
  -- variable names in the golden strings (v10, v23, ...) are reproducible.
  resetVarCounter
  let blackGlyph = treplicate (SNat @3) knownSTK
                   $ fromPrimal $ rconcrete $ Nested.rscalar 7
      blackLabel = treplicate (SNat @2) knownSTK
                   $ fromPrimal $ rconcrete $ Nested.rscalar 8
      -- Loss as a function of the parameters only; glyph and label are fixed.
      afcnn2TnonLin :: MnistFcnnRanked2.ADFcnnMnist2Parameters
                         (AstTensor AstMethodLet FullSpan) Double Float
                    -> AstTensor AstMethodLet FullSpan (TKScalar Double)
      afcnn2TnonLin = MnistFcnnRanked2.afcnnMnistLoss2 (blackGlyph, blackLabel)
      -- Full shape of the parameter tuple, taken from the concrete
      -- initial parameters.
      ftk = tftk @Concrete (knownSTK @(XParams2 Double Float))
                 (toTarget @Concrete valsInitVT2OPP)
      artifactRevnonLin =
        revArtifactAdapt UseIncomingCotangent afcnn2TnonLin ftk
  -- Primal part, simplified.
  printArtifactPrimalPretty (simplifyArtifact artifactRevnonLin)
    @?= "\\m1 -> let v23 = exp (sdot1In (sreplicate @2 (recip (sconcrete (sreplicate [5] 1.0) + exp (negate (scast (sdot1In (sreplicate @5 (scast (recip (sconcrete (sreplicate [4] 1.0) + exp (negate (sdot1In (sconcrete (sreplicate [4,3] 7.0)) (sfromR (tproject1 (tproject1 (tproject1 m1))))) + negate (sfromR (tproject2 (tproject1 (tproject1 m1))))))))) (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + negate (sfromR (tproject2 (tproject2 (tproject1 m1)))))))) (sfromR (tproject1 (tproject2 m1))) + sfromR (tproject2 (tproject2 m1))) in kfromS (negate (sdot0 (sconcrete (sreplicate [2] 8.0)) (log (sreplicate @2 (recip (ssum0 v23)) * v23))))"
  -- Primal part, unsimplified.
  printArtifactPrimalPretty artifactRevnonLin
    @?= "\\m1 -> let v10 = ssum @3 (sconcrete (sreplicate [3,4] 7.0) * str (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1))) ; v11 = exp (negate v10) ; v12 = sconcrete (sreplicate [4] 1.0) + v11 ; v13 = recip v12 ; m16 = str (sreplicate @5 (scast v13)) ; v17 = scast (ssum @4 (m16 * str (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1))) ; v18 = exp (negate v17) ; v19 = sconcrete (sreplicate [5] 1.0) + v18 ; v20 = recip v19 ; v23 = exp (ssum @5 (str (sreplicate @2 v20) * str (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1))) ; x24 = ssum @2 v23 ; v25 = sreplicate @2 (recip x24) ; v26 = v25 * v23 ; v27 = log v26 in kfromS (negate (ssum @2 (sconcrete (sreplicate [2] 8.0) * v27)))"
  -- Full artifact (derivative), unsimplified.
  printArtifactPretty artifactRevnonLin
    @?= "\\dret m1 -> let v10 = ssum @3 (sconcrete (sreplicate [3,4] 7.0) * str (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1))) ; v11 = exp (negate v10) ; v12 = sconcrete (sreplicate [4] 1.0) + v11 ; v13 = recip v12 ; v14 = sconcrete (sreplicate [4] 1.0) + negate v13 ; v15 = v13 * v14 ; m16 = str (sreplicate @5 (scast v13)) ; v17 = scast (ssum @4 (m16 * str (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1))) ; v18 = exp (negate v17) ; v19 = sconcrete (sreplicate [5] 1.0) + v18 ; v20 = recip v19 ; v21 = sconcrete (sreplicate [5] 1.0) + negate v20 ; v22 = v20 * v21 ; v23 = exp (ssum @5 (str (sreplicate @2 v20) * str (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1))) ; x24 = ssum @2 v23 ; v25 = sreplicate @2 (recip x24) ; v26 = v25 * v23 ; v29 = sconcrete (sreplicate [2] 8.0) * (recip v26 * sreplicate @2 (sscalar (-1.0) * sfromK dret)) ; v30 = v23 * (sreplicate @2 (negate (recip (x24 * x24)) * ssum @2 (v23 * v29)) + v25 * v29) ; v31 = v22 * ssum @2 (str (str (sfromR (tproject1 (tproject2 m1))) * sreplicate @5 v30)) ; m32 = sreplicate @4 (scast v31) ; v33 = v15 * scast (ssum @5 (str (str (sfromR (tproject1 (tproject2 (tproject1 m1)))) * m32))) in tpair (tpair (tpair (rfromS (str (sconcrete (sreplicate [3,4] 7.0) * sreplicate @3 v33))) (rfromS v33)) (tpair (rfromS (str (m16 * m32))) (rfromS v31))) (tpair (rfromS (str (str (sreplicate @2 v20) * sreplicate @5 v30))) (rfromS v30))"
  -- Full artifact (derivative), simplified.
  printArtifactPretty (simplifyArtifact artifactRevnonLin)
    @?= "\\dret m1 -> tconvert (ConvT2 (ConvT2 (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [4,3] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [4] FTKScalar)) ConvSX))) (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [5,4] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [5] FTKScalar)) ConvSX)))) (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,5] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2] FTKScalar)) ConvSX)))) (STKProduct (STKProduct (STKProduct (STKS [4,3] STKScalar) (STKS [4] STKScalar)) (STKProduct (STKS [5,4] STKScalar) (STKS [5] STKScalar))) (STKProduct (STKS [2,5] STKScalar) (STKS [2] STKScalar))) (let v13 = recip (sconcrete (sreplicate [4] 1.0) + exp (negate (sdot1In (sconcrete (sreplicate [4,3] 7.0)) (sfromR (tproject1 (tproject1 (tproject1 m1))))) + negate (sfromR (tproject2 (tproject1 (tproject1 m1)))))) ; v16 = scast v13 ; v20 = recip (sconcrete (sreplicate [5] 1.0) + exp (negate (scast (sdot1In (sreplicate @5 v16) (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + negate (sfromR (tproject2 (tproject2 (tproject1 m1)))))) ; v23 = exp (sdot1In (sreplicate @2 v20) (sfromR (tproject1 (tproject2 m1))) + sfromR (tproject2 (tproject2 m1))) ; x24 = ssum0 v23 ; x25 = recip x24 ; v29 = sconcrete (sreplicate [2] 8.0) * (recip (sreplicate @2 x25 * v23) * sreplicate @2 (sscalar (-1.0) * sfromK dret)) ; v30 = v23 * (sreplicate @2 (negate (recip (x24 * x24)) * sdot0 v23 v29) + sreplicate @2 x25 * v29) ; v31 = (v20 * (sconcrete (sreplicate [5] 1.0) + negate v20)) * sdot1In (str (sfromR (tproject1 (tproject2 m1)))) (sreplicate @5 v30) ; v32 = scast v31 ; v33 = (v13 * (sconcrete (sreplicate [4] 1.0) + negate v13)) * scast (sdot1In (str (sfromR (tproject1 (tproject2 (tproject1 m1))))) (sreplicate @4 v32)) in tpair (tpair (tpair (sconcrete (sreplicate [4,3] 7.0) * str (sreplicate @3 v33)) v33) (tpair (sreplicate @5 v16 * str (sreplicate @4 v32)) v31)) (tpair (sreplicate @2 v20 * str (sreplicate @5 v30)) v30))"
-- | Checks that interpreting the AST of the MNIST FCNN loss agrees with
-- running the same loss directly on concrete parameters.
--
-- The loss is built once over an AST variable standing for the whole
-- parameter tuple. That term is then interpreted in an environment binding
-- the variable to @valsInitVT2OPP@, three times: as is, after
-- 'simplifyInline' and after 'simplifyInlineContract'. Each result must
-- equal @afcnn2 valsInitVT2OPP@.
testVT2OAstNonLin3 :: Assertion
testVT2OAstNonLin3 = do
  let -- Full shape of the parameters, taken from the concrete initial values.
      ftk = tftk @Concrete (knownSTK @(XParams2 Double Float))
                 (toTarget @Concrete valsInitVT2OPP)
      -- AST variable (with a fixed, hard-coded id) that stands for the
      -- parameters inside the term.
      varName = mkAstVarName ftk Nothing . intToAstVarId $ 100000000
      var :: AstTensor AstMethodLet FullSpan (XParams2 Double Float)
      var = AstVar varName
      vals = toTarget @Concrete valsInitVT2OPP
      -- Environment mapping the variable to the concrete parameters.
      env = extendEnv varName vals emptyEnv
      -- Constant input (three 7s) and constant label (two 8s).
      blackGlyph = treplicate (SNat @3) knownSTK $ rscalar 7
      blackLabel = treplicate (SNat @2) knownSTK $ rscalar 8
      -- The loss, polymorphic in the target so that it can be run both
      -- symbolically (on AST terms) and concretely.
      afcnn2 :: ADReady f
             => MnistFcnnRanked2.ADFcnnMnist2Parameters f Double Float
             -> f (TKScalar Double)
      afcnn2 = MnistFcnnRanked2.afcnnMnistLoss2
                 ( rconcrete $ unConcrete blackGlyph
                 , rconcrete $ unConcrete blackLabel )
      -- The loss as an AST term over @var@.
      afcnn1 = afcnn2 $ fromTarget var
  interpretAstFull @Concrete env afcnn1
    @?= afcnn2 valsInitVT2OPP
  interpretAstFull @Concrete env
                   (simplifyInline @(TKScalar Double) afcnn1)
    @?= afcnn2 valsInitVT2OPP
  interpretAstFull @Concrete env
                   (simplifyInlineContract @(TKScalar Double) afcnn1)
    @?= afcnn2 valsInitVT2OPP
-- | Test group collecting the pretty-printing (PP) and AST test cases
-- for the ranked RNN MNIST model.
tensorMnistPPRNNR :: TestTree
tensorMnistPPRNNR = testGroup "PP and Ast tests for RNNR MNIST"
  [ testCase "RNNO PP" testRNNOPP
  , testCase "RNNO Ast" testRNNOAst
  , testCase "RNNO PP 2" testRNNOPP2
  , testCase "RNNO Ast 2" testRNNOAst2
  ]
-- | Deterministic initial parameters for the RNN MNIST tests.
--
-- Every tensor is filled, in linear order, with @0, 1, 2, ...@ converted
-- to 'Double', so the values depend only on the two size arguments and on
-- @sizeMnistLabelInt@.
--
-- Shapes, in order:
--
-- * first triple:  @[out_width, sizeMnistHeightI]@, @[out_width, out_width]@,
--   @[out_width]@
-- * second triple: @[out_width, out_width]@, @[out_width, out_width]@,
--   @[out_width]@
-- * final pair:    @[sizeMnistLabelInt, out_width]@, @[sizeMnistLabelInt]@
--
-- The shape literals rely on @OverloadedLists@.
valsInitRNNOPP
  :: Int -> Int -> ADRnnMnistParameters Concrete Double
valsInitRNNOPP out_width sizeMnistHeightI =
  ( ( Concrete
      $ Nested.rfromListPrimLinear [out_width, sizeMnistHeightI]
          (map fromIntegral [0 .. out_width * sizeMnistHeightI - 1])
    , Concrete
      $ Nested.rfromListPrimLinear [out_width, out_width]
          (map fromIntegral [0 .. out_width * out_width - 1])
    , Concrete
      $ Nested.rfromListPrimLinear [out_width]
          (map fromIntegral [0 .. out_width - 1]) )
  , ( Concrete
      $ Nested.rfromListPrimLinear [out_width, out_width]
          (map fromIntegral [0 .. out_width * out_width - 1])
    , Concrete
      $ Nested.rfromListPrimLinear [out_width, out_width]
          (map fromIntegral [0 .. out_width * out_width - 1])
    , Concrete
      $ Nested.rfromListPrimLinear [out_width]
          (map fromIntegral [0 .. out_width - 1]) )
  , ( Concrete
      $ Nested.rfromListPrimLinear [sizeMnistLabelInt, out_width]
          (map fromIntegral [0 .. sizeMnistLabelInt * out_width - 1])
    , Concrete
      $ Nested.rfromListPrimLinear [sizeMnistLabelInt]
          (map fromIntegral [0 .. sizeMnistLabelInt - 1]) ) )
-- | Golden pretty-printing test for the reverse-derivative artifact of
-- @MnistRnnRanked2.rnnMnistZeroR@.
--
-- The artifact is built with 'revArtifactAdapt' for batch size 1 and a
-- constant input of shape @[1, 1, 1]@ filled with 7. Four renderings are
-- compared against expected strings: the primal part and the whole
-- artifact, each with and without 'simplifyArtifact'.
testRNNOPP :: Assertion
testRNNOPP = do
  -- NOTE(review): presumably resets fresh-variable numbering so that the
  -- variable names in the expected strings below are stable -- confirm.
  resetVarCounter
  let batch_size = 1
      sizeMnistHeightI = 1
      -- A rank-3 AST term: a scalar 7 replicated once along each of three
      -- new dimensions.
      blackGlyph :: AstTensor AstMethodLet PrimalSpan (TKR 3 Double)
      blackGlyph = AstReplicate (SNat @1) knownSTK
                   $ AstReplicate (SNat @1) knownSTK
                   $ AstReplicate (SNat @1) knownSTK
                       (rconcrete $ Nested.rscalar 7
                        :: AstTensor AstMethodLet PrimalSpan (TKR 0 Double))
      afcnn2T :: ADRnnMnistParameters (AstTensor AstMethodLet FullSpan)
                                      Double
              -> AstTensor AstMethodLet FullSpan (TKR 2 Double)
      afcnn2T = MnistRnnRanked2.rnnMnistZeroR batch_size blackGlyph
      -- Full shape of the parameters, taken from concrete initial values.
      ftk = tftk @Concrete
                 (knownSTK @(X (ADRnnMnistParameters Concrete Double)))
                 (toTarget @Concrete $ valsInitRNNOPP 1 sizeMnistHeightI)
      artifactRev = revArtifactAdapt UseIncomingCotangent afcnn2T ftk
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\m1 -> rfromS (str (sreplicate @1 (str (sfromR (tproject1 (tproject2 m1))) !$ [0] * sreplicate @10 (tanh (sfromR (tproject1 (tproject1 (tproject2 (tproject1 m1)))) !$ [0, 0] * tanh (sscalar 7.0 * sfromR (tproject1 (tproject1 (tproject1 (tproject1 m1)))) !$ [0, 0] + sfromR (tproject2 (tproject1 (tproject1 m1))) !$ [0]) + sfromR (tproject2 (tproject2 (tproject1 m1))) !$ [0])))) + str (sreplicate @1 (sfromR (tproject2 (tproject2 m1)))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\m1 -> let x16 = sfromR (tproject1 (tproject1 (tproject1 (tproject1 m1)))) !$ [0, 0] ; x18 = tanh (sscalar 7.0 * x16 + sfromR (tproject2 (tproject1 (tproject1 m1))) !$ [0]) ; x19 = sfromR (tproject1 (tproject1 (tproject2 (tproject1 m1)))) !$ [0, 0] ; x21 = tanh (x19 * x18 + sfromR (tproject2 (tproject2 (tproject1 m1))) !$ [0]) ; v22 = str (sfromR (tproject1 (tproject2 m1))) !$ [0] in rfromS (str (sreplicate @1 (v22 * sreplicate @10 x21)) + str (sreplicate @1 (sfromR (tproject2 (tproject2 m1)))))"
  printArtifactPretty artifactRev
    @?= "\\dret m1 -> let x16 = sfromR (tproject1 (tproject1 (tproject1 (tproject1 m1)))) !$ [0, 0] ; x18 = tanh (sscalar 7.0 * x16 + sfromR (tproject2 (tproject1 (tproject1 m1))) !$ [0]) ; x19 = sfromR (tproject1 (tproject1 (tproject2 (tproject1 m1)))) !$ [0, 0] ; x21 = tanh (x19 * x18 + sfromR (tproject2 (tproject2 (tproject1 m1))) !$ [0]) ; v22 = str (sfromR (tproject1 (tproject2 m1))) !$ [0] ; x24 = (sscalar 1.0 + negate x21 * x21) * ssum @10 (v22 * ssum @1 (str (sfromR dret))) ; x25 = (sscalar 1.0 + negate x18 * x18) * (x19 * x24) in tpair (tpair (tpair (tpair (rfromS (soneHot (sscalar 7.0 * x25) [0, 0])) (rfromS (soneHot (sscalar 0.0) [0, 0]))) (rfromS (soneHot x25 [0]))) (tpair (tpair (rfromS (soneHot (x18 * x24) [0, 0])) (rfromS (soneHot (sscalar 0.0) [0, 0]))) (rfromS (soneHot x24 [0])))) (tpair (rfromS (str (soneHot (sreplicate @10 x21 * ssum @1 (str (sfromR dret))) [0]))) (rfromS (ssum @1 (str (sfromR dret)))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret m1 -> tconvert (ConvT2 (ConvT2 (ConvT2 (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [1,1] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [1,1] FTKScalar)) ConvSX))) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [1] FTKScalar)) ConvSX))) (ConvT2 (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [1,1] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [1,1] FTKScalar)) ConvSX))) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [1] FTKScalar)) ConvSX)))) (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [10,1] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [10] FTKScalar)) ConvSX)))) (STKProduct (STKProduct (STKProduct (STKProduct (STKS [1,1] STKScalar) (STKS [1,1] STKScalar)) (STKS [1] STKScalar)) (STKProduct (STKProduct (STKS [1,1] STKScalar) (STKS [1,1] STKScalar)) (STKS [1] STKScalar))) (STKProduct (STKS [10,1] STKScalar) (STKS [10] STKScalar))) (let x18 = tanh (sscalar 7.0 * sfromR (tproject1 (tproject1 (tproject1 (tproject1 m1)))) !$ [0, 0] + sfromR (tproject2 (tproject1 (tproject1 m1))) !$ [0]) ; x19 = sfromR (tproject1 (tproject1 (tproject2 (tproject1 m1)))) !$ [0, 0] ; x21 = tanh (x19 * x18 + sfromR (tproject2 (tproject2 (tproject1 m1))) !$ [0]) ; x24 = (sscalar 1.0 + negate x21 * x21) * sdot0 (str (sfromR (tproject1 (tproject2 m1))) !$ [0]) (str (sfromR dret) !$ [0]) ; x25 = (sscalar 1.0 + negate x18 * x18) * (x19 * x24) in tpair (tpair (tpair (tpair (sreplicate @1 (sreplicate @1 (sscalar 7.0 * x25))) (sconcrete (sfromListLinear [1,1] [0.0]))) (sreplicate @1 x25)) (tpair (tpair (sreplicate @1 (sreplicate @1 (x18 * x24))) (sconcrete (sfromListLinear [1,1] [0.0]))) (sreplicate @1 x24))) (tpair (str (sreplicate @1 (sreplicate @10 x21 * str (sfromR dret) !$ [0]))) (str (sfromR dret) !$ [0])))"
-- | Checks that interpreting the AST of the ranked RNN agrees with running
-- the same network directly on concrete parameters.
--
-- The network ('MnistRnnRanked2.rnnMnistZeroR' applied to a constant input
-- glyph) is first instantiated over an AST variable, yielding the term
-- @afcnn1@. That term is then interpreted with 'interpretAstFull' in an
-- environment that binds the variable to the concrete initial parameters,
-- and the result is compared with @afcnn2@ applied to those parameters.
-- The comparison is made three times: on the raw term, after
-- 'simplifyInline', and after 'simplifyInlineContract'.
testRNNOAst :: Assertion
testRNNOAst = do
  let batch_size = 1
      sizeMnistHeightI = 1
      -- Full shape of the parameter tuple, read off the concrete initial
      -- parameters.
      ftk = tftk @Concrete
                 (knownSTK @(X (ADRnnMnistParameters Concrete Double)))
                 (toTarget @Concrete $ valsInitRNNOPP 1 sizeMnistHeightI)
      -- NOTE(review): the variable id is a hard-coded large literal,
      -- presumably to stay clear of freshly generated ids -- confirm.
      varName = mkAstVarName ftk Nothing . intToAstVarId $ 100000000
      var :: AstTensor AstMethodLet FullSpan
                       (X (ADRnnMnistParameters Concrete Double))
      var = AstVar varName
      vals = toTarget @Concrete $ valsInitRNNOPP 1 sizeMnistHeightI
      -- Environment in which the AST variable denotes the concrete
      -- parameters.
      env = extendEnv varName vals emptyEnv
      -- Rank-3 input built by replicating the scalar 7 once along each of
      -- three dimensions.
      blackGlyph = treplicate (SNat @1) knownSTK
                   $ treplicate (SNat @1) knownSTK
                   $ treplicate (SNat @1) knownSTK $ rscalar 7
      -- The network under test, polymorphic in the target so it can be run
      -- both symbolically (on AST terms) and concretely.
      afcnn2 :: ADReady f
             => ADRnnMnistParameters f Double
             -> f (TKR 2 Double)
      afcnn2 = MnistRnnRanked2.rnnMnistZeroR
                 batch_size (rconcrete $ unConcrete blackGlyph)
      -- Symbolic instantiation: the network applied to the AST variable.
      afcnn1 = afcnn2 $ fromTarget var
  interpretAstFull @Concrete env afcnn1
    @?= afcnn2 (valsInitRNNOPP 1 sizeMnistHeightI)
  interpretAstFull @Concrete env
                   (simplifyInline @(TKR 2 Double) afcnn1)
    @?= afcnn2 (valsInitRNNOPP 1 sizeMnistHeightI)
  interpretAstFull @Concrete env
                   (simplifyInlineContract @(TKR 2 Double) afcnn1)
    @?= afcnn2 (valsInitRNNOPP 1 sizeMnistHeightI)
testRNNOPP2 :: Assertion
testRNNOPP2 :: Assertion
testRNNOPP2 = do
Assertion
resetVarCounter
let batch_size :: Int
batch_size = Int
2
sizeMnistHeightI :: Int
sizeMnistHeightI = Int
2
blackGlyph :: AstTensor AstMethodLet PrimalSpan (TKR 3 Double)
blackGlyph :: AstTensor AstMethodLet PrimalSpan (TKR 3 Double)
blackGlyph = SNat 2
-> SingletonTK (BuildTensorKind 2 (TKR2 1 (TKScalar Double)))
-> AstTensor
AstMethodLet
PrimalSpan
(BuildTensorKind 2 (TKR2 1 (TKScalar Double)))
-> AstTensor
AstMethodLet
PrimalSpan
(BuildTensorKind 2 (BuildTensorKind 2 (TKR2 1 (TKScalar Double))))
forall (y :: TK) (k :: Natural) (a :: AstMethodOfSharing)
(b :: AstSpanType).
SNat k
-> SingletonTK y
-> AstTensor a b y
-> AstTensor a b (BuildTensorKind k y)
AstReplicate (forall (n :: Natural). KnownNat n => SNat n
SNat @2) SingletonTK (BuildTensorKind 2 (TKR2 1 (TKScalar Double)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK
(AstTensor
AstMethodLet
PrimalSpan
(BuildTensorKind 2 (TKR2 1 (TKScalar Double)))
-> AstTensor
AstMethodLet
PrimalSpan
(BuildTensorKind 2 (BuildTensorKind 2 (TKR2 1 (TKScalar Double)))))
-> AstTensor
AstMethodLet
PrimalSpan
(BuildTensorKind 2 (TKR2 1 (TKScalar Double)))
-> AstTensor
AstMethodLet
PrimalSpan
(BuildTensorKind 2 (BuildTensorKind 2 (TKR2 1 (TKScalar Double))))
forall a b. (a -> b) -> a -> b
$ SNat 2
-> SingletonTK (TKR2 1 (TKScalar Double))
-> AstTensor AstMethodLet PrimalSpan (TKR2 1 (TKScalar Double))
-> AstTensor
AstMethodLet
PrimalSpan
(BuildTensorKind 2 (TKR2 1 (TKScalar Double)))
forall (y :: TK) (k :: Natural) (a :: AstMethodOfSharing)
(b :: AstSpanType).
SNat k
-> SingletonTK y
-> AstTensor a b y
-> AstTensor a b (BuildTensorKind k y)
AstReplicate (forall (n :: Natural). KnownNat n => SNat n
SNat @2) SingletonTK (TKR2 1 (TKScalar Double))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK
(AstTensor AstMethodLet PrimalSpan (TKR2 1 (TKScalar Double))
-> AstTensor
AstMethodLet
PrimalSpan
(BuildTensorKind 2 (TKR2 1 (TKScalar Double))))
-> AstTensor AstMethodLet PrimalSpan (TKR2 1 (TKScalar Double))
-> AstTensor
AstMethodLet
PrimalSpan
(BuildTensorKind 2 (TKR2 1 (TKScalar Double)))
forall a b. (a -> b) -> a -> b
$ SNat 2
-> SingletonTK (TKR 0 Double)
-> AstTensor AstMethodLet PrimalSpan (TKR 0 Double)
-> AstTensor
AstMethodLet PrimalSpan (BuildTensorKind 2 (TKR 0 Double))
forall (y :: TK) (k :: Natural) (a :: AstMethodOfSharing)
(b :: AstSpanType).
SNat k
-> SingletonTK y
-> AstTensor a b y
-> AstTensor a b (BuildTensorKind k y)
AstReplicate (forall (n :: Natural). KnownNat n => SNat n
SNat @2) SingletonTK (TKR 0 Double)
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK
(Ranked 0 Double -> AstTensor AstMethodLet PrimalSpan (TKR 0 Double)
forall r (target :: Target) (n :: Natural).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete (Ranked 0 Double
-> AstTensor AstMethodLet PrimalSpan (TKR 0 Double))
-> Ranked 0 Double
-> AstTensor AstMethodLet PrimalSpan (TKR 0 Double)
forall a b. (a -> b) -> a -> b
$ Double -> Ranked 0 Double
forall a. Elt a => a -> Ranked 0 a
Nested.rscalar Double
7
:: AstTensor AstMethodLet PrimalSpan (TKR 0 Double))
afcnn2T :: ADRnnMnistParameters (AstTensor AstMethodLet FullSpan)
Double
-> AstTensor AstMethodLet FullSpan (TKR 2 Double)
afcnn2T :: ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) Double
-> AstTensor AstMethodLet FullSpan (TKR 2 Double)
afcnn2T = Int
-> PrimalOf (AstTensor AstMethodLet FullSpan) (TKR 3 Double)
-> ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) Double
-> AstTensor AstMethodLet FullSpan (TKR 2 Double)
forall (target :: Target) r.
(ADReady target, GoodScalar r, Differentiable r) =>
Int
-> PrimalOf target (TKR 3 r)
-> ADRnnMnistParameters target r
-> target (TKR 2 r)
MnistRnnRanked2.rnnMnistZeroR Int
batch_size AstTensor AstMethodLet PrimalSpan (TKR 3 Double)
PrimalOf (AstTensor AstMethodLet FullSpan) (TKR 3 Double)
blackGlyph
ftk :: FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double)))
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
ftk = forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete
(forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(X (ADRnnMnistParameters Concrete Double)))
(forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete (ADRnnMnistParameters Concrete Double
-> Concrete (X (ADRnnMnistParameters Concrete Double)))
-> ADRnnMnistParameters Concrete Double
-> Concrete (X (ADRnnMnistParameters Concrete Double))
forall a b. (a -> b) -> a -> b
$ Int -> Int -> ADRnnMnistParameters Concrete Double
valsInitRNNOPP Int
2 Int
sizeMnistHeightI)
artifactRev :: AstArtifactRev
(X (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) Double))
(TKR 2 Double)
artifactRev = IncomingCotangentHandling
-> (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) Double
-> AstTensor AstMethodLet FullSpan (TKR 2 Double))
-> FullShapeTK
(X (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) Double))
-> AstArtifactRev
(X (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) Double))
(TKR 2 Double)
forall src (ztgt :: TK) tgt.
(AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
(tgt :: Type) ~ (AstTensor AstMethodLet FullSpan ztgt :: Type)) =>
IncomingCotangentHandling
-> (src -> tgt)
-> FullShapeTK (X src)
-> AstArtifactRev (X src) ztgt
revArtifactAdapt IncomingCotangentHandling
UseIncomingCotangent ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) Double
-> AstTensor AstMethodLet FullSpan (TKR 2 Double)
afcnn2T FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double)))
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
FullShapeTK
(X (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) Double))
ftk
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double)))
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR 2 Double)
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPrimalPretty (AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double)))
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR 2 Double)
-> AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double)))
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR 2 Double)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifact AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double)))
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR 2 Double)
AstArtifactRev
(X (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) Double))
(TKR 2 Double)
artifactRev)
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\m1 -> rfromS (let m40 = sappend (tanh (str (sreplicate @2 (sdot1In (sconcrete (sreplicate [2,2] 7.0)) (sfromR (tproject1 (tproject1 (tproject1 (tproject1 m1))))))) + str (sreplicate @2 (sfromR (tproject2 (tproject1 (tproject1 m1))))))) (tanh (str (sreplicate @2 (sdot1In (sfromR (tproject1 (tproject1 (tproject2 (tproject1 m1))))) (sreplicate @2 (tanh (sdot1In (sconcrete (sreplicate [2,2] 7.0)) (sfromR (tproject1 (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1)))))))) + str (sreplicate @2 (sfromR (tproject2 (tproject2 (tproject1 m1))))))) in smatmul2 (sfromR (tproject1 (tproject2 m1))) (tanh ((smatmul2 (sfromR (tproject1 (tproject1 (tproject2 (tproject1 m1))))) (tanh ((str (sreplicate @2 (sdot1In (sconcrete (sreplicate [2,2] 7.0)) (sfromR (tproject1 (tproject1 (tproject1 (tproject1 m1))))))) + smatmul2 (sfromR (tproject2 (tproject1 (tproject1 (tproject1 m1))))) (sslice (SNat @0) (SNat @2) m40)) + str (sreplicate @2 (sfromR (tproject2 (tproject1 (tproject1 m1))))))) + smatmul2 (sfromR (tproject2 (tproject1 (tproject2 (tproject1 m1))))) (sslice (SNat @2) (SNat @2) m40)) + str (sreplicate @2 (sfromR (tproject2 (tproject2 (tproject1 m1))))))) + str (sreplicate @2 (sfromR (tproject2 (tproject2 m1)))))"
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double)))
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR 2 Double)
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPrimalPretty AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double)))
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR 2 Double)
AstArtifactRev
(X (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) Double))
(TKR 2 Double)
artifactRev
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\m1 -> let m37 = tanh ((str (sreplicate @2 (ssum @2 (sconcrete (sreplicate [2,2] 7.0) * str (sfromR (tproject1 (tproject1 (tproject1 (tproject1 m1)))))))) + str (sreplicate @2 (ssum @2 (sconcrete (sreplicate [2,2] 0.0))))) + str (sreplicate @2 (sfromR (tproject2 (tproject1 (tproject1 m1)))))) ; v38 = tanh ((ssum @2 (sconcrete (sreplicate [2,2] 7.0) * str (sfromR (tproject1 (tproject1 (tproject1 (tproject1 m1)))))) + ssum @2 (sconcrete (sreplicate [2,2] 0.0))) + sfromR (tproject2 (tproject1 (tproject1 m1)))) ; m39 = tanh ((str (sreplicate @2 (ssum @2 (str (sfromR (tproject1 (tproject1 (tproject2 (tproject1 m1))))) * str (sreplicate @2 v38)))) + str (sreplicate @2 (ssum @2 (sconcrete (sreplicate [2,2] 0.0))))) + str (sreplicate @2 (sfromR (tproject2 (tproject2 (tproject1 m1)))))) ; m40 = sappend m37 m39 ; m41 = tanh ((sreplicate @2 (ssum @2 (sconcrete (sreplicate [2,2] 7.0) * str (sfromR (tproject1 (tproject1 (tproject1 (tproject1 m1))))))) + ssum @2 (str (sreplicate @2 (str (sfromR (tproject2 (tproject1 (tproject1 (tproject1 m1))))))) * stranspose @[2,1,0] (sreplicate @2 (str (sslice (SNat @0) (SNat @2) m40))))) + sreplicate @2 (sfromR (tproject2 (tproject1 (tproject1 m1))))) ; m42 = tanh ((ssum @2 (stranspose @[1,2,0] (sreplicate @2 (str (sfromR (tproject1 (tproject1 (tproject2 (tproject1 m1))))))) * stranspose @[2,0,1] (sreplicate @2 m41)) + ssum @2 (stranspose @[1,2,0] (sreplicate @2 (str (sfromR (tproject2 (tproject1 (tproject2 (tproject1 m1))))))) * stranspose @[2,0,1] (sreplicate @2 (str (sslice (SNat @2) (SNat @2) m40))))) + str (sreplicate @2 (sfromR (tproject2 (tproject2 (tproject1 m1)))))) in rfromS (ssum @2 (stranspose @[2,1,0] (sreplicate @2 (sfromR (tproject1 (tproject2 m1)))) * str (sreplicate @10 m42)) + str (sreplicate @2 (sfromR (tproject2 (tproject2 m1)))))"
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double)))
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR 2 Double)
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPretty AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double)))
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR 2 Double)
AstArtifactRev
(X (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) Double))
(TKR 2 Double)
artifactRev
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\dret m1 -> let m37 = tanh ((str (sreplicate @2 (ssum @2 (sconcrete (sreplicate [2,2] 7.0) * str (sfromR (tproject1 (tproject1 (tproject1 (tproject1 m1)))))))) + str (sreplicate @2 (ssum @2 (sconcrete (sreplicate [2,2] 0.0))))) + str (sreplicate @2 (sfromR (tproject2 (tproject1 (tproject1 m1)))))) ; v38 = tanh ((ssum @2 (sconcrete (sreplicate [2,2] 7.0) * str (sfromR (tproject1 (tproject1 (tproject1 (tproject1 m1)))))) + ssum @2 (sconcrete (sreplicate [2,2] 0.0))) + sfromR (tproject2 (tproject1 (tproject1 m1)))) ; m39 = tanh ((str (sreplicate @2 (ssum @2 (str (sfromR (tproject1 (tproject1 (tproject2 (tproject1 m1))))) * str (sreplicate @2 v38)))) + str (sreplicate @2 (ssum @2 (sconcrete (sreplicate [2,2] 0.0))))) + str (sreplicate @2 (sfromR (tproject2 (tproject2 (tproject1 m1)))))) ; m40 = sappend m37 m39 ; m41 = tanh ((sreplicate @2 (ssum @2 (sconcrete (sreplicate [2,2] 7.0) * str (sfromR (tproject1 (tproject1 (tproject1 (tproject1 m1))))))) + ssum @2 (str (sreplicate @2 (str (sfromR (tproject2 (tproject1 (tproject1 (tproject1 m1))))))) * stranspose @[2,1,0] (sreplicate @2 (str (sslice (SNat @0) (SNat @2) m40))))) + sreplicate @2 (sfromR (tproject2 (tproject1 (tproject1 m1))))) ; m42 = tanh ((ssum @2 (stranspose @[1,2,0] (sreplicate @2 (str (sfromR (tproject1 (tproject1 (tproject2 (tproject1 m1))))))) * stranspose @[2,0,1] (sreplicate @2 m41)) + ssum @2 (stranspose @[1,2,0] (sreplicate @2 (str (sfromR (tproject2 (tproject1 (tproject2 (tproject1 m1))))))) * stranspose @[2,0,1] (sreplicate @2 (str (sslice (SNat @2) (SNat @2) m40))))) + str (sreplicate @2 (sfromR (tproject2 (tproject2 (tproject1 m1)))))) ; m44 = (sconcrete (sreplicate [2,2] 1.0) + negate m42 * m42) * ssum @10 (str (stranspose @[2,1,0] (sreplicate @2 (sfromR (tproject1 (tproject2 m1)))) * sreplicate @2 (sfromR dret))) ; m45 = (sconcrete (sreplicate [2,2] 1.0) + negate m41 * m41) * ssum @2 (stranspose @[1,2,0] (stranspose @[1,2,0] (sreplicate @2 (str (sfromR (tproject1 (tproject1 (tproject2 
(tproject1 m1))))))) * sreplicate @2 m44)) ; m46 = sappend (sconcrete (sfromListLinear [0,2] [])) (sappend (str (ssum @2 (stranspose @[2,1,0] (str (sreplicate @2 (str (sfromR (tproject2 (tproject1 (tproject1 (tproject1 m1))))))) * sreplicate @2 m45)))) (sconcrete (sreplicate [2,2] 0.0))) + sappend (sconcrete (sreplicate [2,2] 0.0)) (sappend (str (ssum @2 (stranspose @[1,2,0] (stranspose @[1,2,0] (sreplicate @2 (str (sfromR (tproject2 (tproject1 (tproject2 (tproject1 m1))))))) * sreplicate @2 m44)))) (sconcrete (sfromListLinear [0,2] []))) ; m47 = (sconcrete (sreplicate [2,2] 1.0) + negate m39 * m39) * sslice (SNat @2) (SNat @2) m46 ; m48 = sreplicate @2 (ssum @2 (str m47)) ; v49 = (sconcrete (sreplicate [2] 1.0) + negate v38 * v38) * ssum @2 (str (str (sfromR (tproject1 (tproject1 (tproject2 (tproject1 m1))))) * m48)) ; m50 = (sconcrete (sreplicate [2,2] 1.0) + negate m37 * m37) * sslice (SNat @0) (SNat @2) m46 in tpair (tpair (tpair (tpair (rfromS (str (sconcrete (sreplicate [2,2] 7.0) * sreplicate @2 (ssum @2 (str m50))) + (str (sconcrete (sreplicate [2,2] 7.0) * sreplicate @2 v49) + str (sconcrete (sreplicate [2,2] 7.0) * sreplicate @2 (ssum @2 m45))))) (rfromS (str (sconcrete (sreplicate [2,2] 0.0)) + (str (sconcrete (sreplicate [2,2] 0.0)) + str (ssum @2 (str (stranspose @[2,1,0] (sreplicate @2 (str (sslice (SNat @0) (SNat @2) m40))) * sreplicate @2 m45))))))) (rfromS (ssum @2 (str m50) + (v49 + ssum @2 m45)))) (tpair (tpair (rfromS (str (str (sreplicate @2 v38) * m48) + str (ssum @2 (stranspose @[2,0,1] (stranspose @[2,0,1] (sreplicate @2 m41) * sreplicate @2 m44))))) (rfromS (str (sconcrete (sreplicate [2,2] 0.0)) + str (ssum @2 (stranspose @[2,0,1] (stranspose @[2,0,1] (sreplicate @2 (str (sslice (SNat @2) (SNat @2) m40))) * sreplicate @2 m44)))))) (rfromS (ssum @2 (str m47) + ssum @2 (str m44))))) (tpair (rfromS (ssum @2 (stranspose @[2,1,0] (str (sreplicate @10 m42) * sreplicate @2 (sfromR dret))))) (rfromS (ssum @2 (str (sfromR dret)))))"
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double)))
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR 2 Double)
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPretty (AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double)))
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR 2 Double)
-> AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double)))
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR 2 Double)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifact AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double)))
(TKProduct
(TKProduct (TKR 2 Double) (TKR 2 Double))
(TKR2 1 (TKScalar Double))))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))
(TKR 2 Double)
AstArtifactRev
(X (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) Double))
(TKR 2 Double)
artifactRev)
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\dret m1 -> tconvert (ConvT2 (ConvT2 (ConvT2 (ConvT2 (ConvCmp (ConvXR STKScalar) ConvSX) (ConvCmp (ConvXR STKScalar) ConvSX)) (ConvCmp (ConvXR STKScalar) ConvSX)) (ConvT2 (ConvT2 (ConvCmp (ConvXR STKScalar) ConvSX) (ConvCmp (ConvXR STKScalar) ConvSX)) (ConvCmp (ConvXR STKScalar) ConvSX))) (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [10,2] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [10] FTKScalar)) ConvSX)))) (STKProduct (STKProduct (STKProduct (STKProduct (STKS [2,2] STKScalar) (STKS [2,2] STKScalar)) (STKS [2] STKScalar)) (STKProduct (STKProduct (STKS [2,2] STKScalar) (STKS [2,2] STKScalar)) (STKS [2] STKScalar))) (STKProduct (STKS [10,2] STKScalar) (STKS [10] STKScalar))) (let m37 = tanh (str (sreplicate @2 (sdot1In (sconcrete (sreplicate [2,2] 7.0)) (sfromR (tproject1 (tproject1 (tproject1 (tproject1 m1))))))) + str (sreplicate @2 (sfromR (tproject2 (tproject1 (tproject1 m1)))))) ; v38 = tanh (sdot1In (sconcrete (sreplicate [2,2] 7.0)) (sfromR (tproject1 (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1)))) ; m39 = tanh (str (sreplicate @2 (sdot1In (sfromR (tproject1 (tproject1 (tproject2 (tproject1 m1))))) (sreplicate @2 v38))) + str (sreplicate @2 (sfromR (tproject2 (tproject2 (tproject1 m1)))))) ; m40 = sappend m37 m39 ; m41 = tanh ((sreplicate @2 (sdot1In (sconcrete (sreplicate [2,2] 7.0)) (sfromR (tproject1 (tproject1 (tproject1 (tproject1 m1)))))) + smatmul2 (str (sslice (SNat @0) (SNat @2) m40)) (str (sfromR (tproject2 (tproject1 (tproject1 (tproject1 m1))))))) + sreplicate @2 (sfromR (tproject2 (tproject1 (tproject1 m1))))) ; m42 = tanh ((smatmul2 (sfromR (tproject1 (tproject1 (tproject2 (tproject1 m1))))) (str m41) + smatmul2 (sfromR (tproject2 (tproject1 (tproject2 (tproject1 m1))))) (sslice (SNat @2) (SNat @2) m40)) + str (sreplicate @2 (sfromR (tproject2 (tproject2 (tproject1 m1)))))) ; m44 = (sconcrete (sreplicate [2,2] 1.0) + negate m42 * m42) * smatmul2 (str 
(sfromR (tproject1 (tproject2 m1)))) (sfromR dret) ; m45 = (sconcrete (sreplicate [2,2] 1.0) + negate m41 * m41) * smatmul2 (str m44) (sfromR (tproject1 (tproject1 (tproject2 (tproject1 m1))))) ; m46 = sappend (smatmul2 (str (sfromR (tproject2 (tproject1 (tproject1 (tproject1 m1)))))) (str m45)) (sconcrete (sreplicate [2,2] 0.0)) + sappend (sconcrete (sreplicate [2,2] 0.0)) (smatmul2 (str (sfromR (tproject2 (tproject1 (tproject2 (tproject1 m1)))))) m44) ; m47 = (sconcrete (sreplicate [2,2] 1.0) + negate m39 * m39) * sslice (SNat @2) (SNat @2) m46 ; v48 = ssum @2 (str m47) ; v49 = (sconcrete (sreplicate [2] 1.0) + negate v38 * v38) * sdot1In (str (sfromR (tproject1 (tproject1 (tproject2 (tproject1 m1)))))) (sreplicate @2 v48) ; m50 = (sconcrete (sreplicate [2,2] 1.0) + negate m37 * m37) * sslice (SNat @0) (SNat @2) m46 in tpair (tpair (tpair (tpair (sconcrete (sreplicate [2,2] 7.0) * str (sreplicate @2 (ssum @2 (str m50))) + (sconcrete (sreplicate [2,2] 7.0) * str (sreplicate @2 v49) + sconcrete (sreplicate [2,2] 7.0) * str (sreplicate @2 (ssum @2 m45)))) (smatmul2 (str m45) (str (sslice (SNat @0) (SNat @2) m40)))) (ssum @2 (str m50) + (v49 + ssum @2 m45))) (tpair (tpair (sreplicate @2 v38 * str (sreplicate @2 v48) + smatmul2 m44 m41) (smatmul2 m44 (str (sslice (SNat @2) (SNat @2) m40)))) (ssum @2 (str m47) + ssum @2 (str m44)))) (tpair (smatmul2 (sfromR dret) (str m42)) (ssum @2 (str (sfromR dret)))))"
-- | Check that interpreting the AST of the ranked RNN MNIST network gives
-- the same result as running the network directly on concrete parameters.
--
-- The network is instantiated at the AST target over a single variable
-- standing for the whole parameter tuple. The resulting term is interpreted
-- in an environment that binds that variable to the concrete initial
-- parameters. The comparison is made three times: for the term as built,
-- after 'simplifyInline' and after 'simplifyInlineContract'.
testRNNOAst2 :: Assertion
testRNNOAst2 = do
  let batch_size, sizeMnistHeightI :: Int
      batch_size = 2
      sizeMnistHeightI = 2
      -- Initial parameters, bound once and shared by the environment,
      -- the shape computation and the expected value below.
      valsInit :: ADRnnMnistParameters Concrete Double
      valsInit = valsInitRNNOPP 2 sizeMnistHeightI
      vals = toTarget @Concrete valsInit
      -- Full shape of the parameter tuple, taken from the concrete values.
      ftk = tftk @Concrete
                 (knownSTK @(X (ADRnnMnistParameters Concrete Double)))
                 vals
      -- NOTE(review): the id 100000000 is presumably chosen to stay clear
      -- of freshly generated variable ids -- confirm.
      varName = mkAstVarName ftk Nothing . intToAstVarId $ 100000000
      var :: AstTensor AstMethodLet FullSpan
                       (X (ADRnnMnistParameters Concrete Double))
      var = AstVar varName
      env = extendEnv varName vals emptyEnv
      -- Rank-3 input built by replicating the scalar 7 twice along each
      -- of three dimensions.
      blackGlyph = treplicate (SNat @2) knownSTK
                   $ treplicate (SNat @2) knownSTK
                   $ treplicate (SNat @2) knownSTK $ rscalar 7
      -- The network, polymorphic in the target so that it can be run both
      -- on AST terms and on concrete tensors.
      afcnn2 :: ADReady f
             => ADRnnMnistParameters f Double
             -> f (TKR 2 Double)
      afcnn2 = MnistRnnRanked2.rnnMnistZeroR
                 batch_size (rconcrete $ unConcrete blackGlyph)
      -- The network as an AST term over the parameter variable.
      afcnn1 = afcnn2 $ fromTarget var
      -- Reference result: the network run directly on concrete parameters.
      expected :: Concrete (TKR 2 Double)
      expected = afcnn2 valsInit
  interpretAstFull @Concrete env afcnn1
    @?= expected
  interpretAstFull @Concrete env
                   (simplifyInline @(TKR 2 Double) afcnn1)
    @?= expected
  interpretAstFull @Concrete env
                   (simplifyInlineContract @(TKR 2 Double) afcnn1)
    @?= expected
-- | Test group collecting the pretty-printing and AST interpretation tests
-- for the CNN MNIST network. The individual assertions are defined
-- elsewhere in this module.
tensorMnistCNNRPP :: TestTree
tensorMnistCNNRPP = testGroup "Ast tests for CNNR MNIST"
  [ testCase "CNNO PP 1" testCNNOPP1
  , testCase "CNNO Ast 1" testCNNOAst1
  , testCase "CNNO PP 2" testCNNOPP2
  , testCase "CNNO Ast 2" testCNNOAst2
  , testCase "CNNO PP 2S" testCNNOPP2S
  ]
testCNNOPP1 :: Assertion
testCNNOPP1 :: Assertion
testCNNOPP1 = do
Assertion
resetVarCounter
let batch_size :: Int
batch_size = Int
5
sizeMnistWidthI :: Int
sizeMnistWidthI = Int
7
sizeMnistHeightI :: Int
sizeMnistHeightI = Int
9
ftk :: FullShapeTK
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
ftk = forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete
(forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(X (MnistCnnRanked2.ADCnnMnistParameters
Concrete Double)))
Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
vals
valsInit :: MnistCnnRanked2.ADCnnMnistParameters Concrete Double
valsInit :: ADCnnMnistParameters Concrete Double
valsInit =
((Concrete
(TKS
((':)
@Natural
1
((':)
@Natural 1 ((':) @Natural 2 ((':) @Natural 2 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS
((':)
@Natural
1
((':)
@Natural 1 ((':) @Natural 2 ((':) @Natural 2 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS ((':) @Natural 1 ((':) @Natural 2 ('[] @Natural))) Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 1 ('[] @Natural)))
Double),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
-> NoShape
((Concrete
(TKS
((':)
@Natural
1
((':)
@Natural 1 ((':) @Natural 2 ((':) @Natural 2 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS
((':)
@Natural
1
((':)
@Natural 1 ((':) @Natural 2 ((':) @Natural 2 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS ((':) @Natural 1 ((':) @Natural 2 ('[] @Natural))) Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 1 ('[] @Natural)))
Double),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape (((Concrete
(TKS
((':)
@Natural
1
((':)
@Natural 1 ((':) @Natural 2 ((':) @Natural 2 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS
((':)
@Natural
1
((':)
@Natural 1 ((':) @Natural 2 ((':) @Natural 2 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS ((':) @Natural 1 ((':) @Natural 2 ('[] @Natural))) Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 1 ('[] @Natural)))
Double),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
-> NoShape
((Concrete
(TKS
((':)
@Natural
1
((':)
@Natural 1 ((':) @Natural 2 ((':) @Natural 2 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS
((':)
@Natural
1
((':)
@Natural 1 ((':) @Natural 2 ((':) @Natural 2 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS ((':) @Natural 1 ((':) @Natural 2 ('[] @Natural))) Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 1 ('[] @Natural)))
Double),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
-> ((Concrete
(TKS
((':)
@Natural
1
((':)
@Natural 1 ((':) @Natural 2 ((':) @Natural 2 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS
((':)
@Natural
1
((':)
@Natural 1 ((':) @Natural 2 ((':) @Natural 2 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS ((':) @Natural 1 ((':) @Natural 2 ('[] @Natural))) Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 1 ('[] @Natural)))
Double),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
-> NoShape
((Concrete
(TKS
((':)
@Natural
1
((':)
@Natural 1 ((':) @Natural 2 ((':) @Natural 2 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS
((':)
@Natural
1
((':)
@Natural 1 ((':) @Natural 2 ((':) @Natural 2 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS ((':) @Natural 1 ((':) @Natural 2 ('[] @Natural))) Double),
Concrete (TKS ((':) @Natural 1 ('[] @Natural)) Double)),
(Concrete
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 1 ('[] @Natural)))
Double),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
forall a b. (a -> b) -> a -> b
$ (ADCnnMnistParametersShaped Concrete 7 9 1 1 1 1 Double, StdGen)
-> ADCnnMnistParametersShaped Concrete 7 9 1 1 1 1 Double
forall a b. (a, b) -> a
fst
((ADCnnMnistParametersShaped Concrete 7 9 1 1 1 1 Double, StdGen)
-> ADCnnMnistParametersShaped Concrete 7 9 1 1 1 1 Double)
-> (ADCnnMnistParametersShaped Concrete 7 9 1 1 1 1 Double, StdGen)
-> ADCnnMnistParametersShaped Concrete 7 9 1 1 1 1 Double
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue @(MnistCnnRanked2.ADCnnMnistParametersShaped
Concrete 7 9
1 1 1 1 Double)
Double
0.4 (Int -> StdGen
mkStdGen Int
44)
vals :: Concrete (X (ADCnnMnistParameters Concrete Double))
vals = forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete ADCnnMnistParameters Concrete Double
valsInit
blackGlyph :: Concrete
(BuildTensorKind
5
(BuildTensorKind 1 (BuildTensorKind 7 (TKR2 1 (TKScalar Double)))))
blackGlyph = SNat 5
-> SingletonTK
(BuildTensorKind 1 (BuildTensorKind 7 (TKR2 1 (TKScalar Double))))
-> Concrete
(BuildTensorKind 1 (BuildTensorKind 7 (TKR2 1 (TKScalar Double))))
-> Concrete
(BuildTensorKind
5
(BuildTensorKind 1 (BuildTensorKind 7 (TKR2 1 (TKScalar Double)))))
forall (z :: TK) (k :: Natural).
ConvertTensor Concrete =>
SNat k
-> SingletonTK z -> Concrete z -> Concrete (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Natural).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate (forall (n :: Natural). KnownNat n => SNat n
SNat @5) SingletonTK
(BuildTensorKind 1 (BuildTensorKind 7 (TKR2 1 (TKScalar Double))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK
(Concrete
(BuildTensorKind 1 (BuildTensorKind 7 (TKR2 1 (TKScalar Double))))
-> Concrete
(BuildTensorKind
5
(BuildTensorKind
1 (BuildTensorKind 7 (TKR2 1 (TKScalar Double))))))
-> Concrete
(BuildTensorKind 1 (BuildTensorKind 7 (TKR2 1 (TKScalar Double))))
-> Concrete
(BuildTensorKind
5
(BuildTensorKind 1 (BuildTensorKind 7 (TKR2 1 (TKScalar Double)))))
forall a b. (a -> b) -> a -> b
$ SNat 1
-> SingletonTK (BuildTensorKind 7 (TKR2 1 (TKScalar Double)))
-> Concrete (BuildTensorKind 7 (TKR2 1 (TKScalar Double)))
-> Concrete
(BuildTensorKind 1 (BuildTensorKind 7 (TKR2 1 (TKScalar Double))))
forall (z :: TK) (k :: Natural).
ConvertTensor Concrete =>
SNat k
-> SingletonTK z -> Concrete z -> Concrete (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Natural).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate (forall (n :: Natural). KnownNat n => SNat n
SNat @1) SingletonTK (BuildTensorKind 7 (TKR2 1 (TKScalar Double)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK
(Concrete (BuildTensorKind 7 (TKR2 1 (TKScalar Double)))
-> Concrete
(BuildTensorKind 1 (BuildTensorKind 7 (TKR2 1 (TKScalar Double)))))
-> Concrete (BuildTensorKind 7 (TKR2 1 (TKScalar Double)))
-> Concrete
(BuildTensorKind 1 (BuildTensorKind 7 (TKR2 1 (TKScalar Double))))
forall a b. (a -> b) -> a -> b
$ SNat 7
-> SingletonTK (TKR2 1 (TKScalar Double))
-> Concrete (TKR2 1 (TKScalar Double))
-> Concrete (BuildTensorKind 7 (TKR2 1 (TKScalar Double)))
forall (z :: TK) (k :: Natural).
ConvertTensor Concrete =>
SNat k
-> SingletonTK z -> Concrete z -> Concrete (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Natural).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate (forall (n :: Natural). KnownNat n => SNat n
SNat @7) SingletonTK (TKR2 1 (TKScalar Double))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK
(Concrete (TKR2 1 (TKScalar Double))
-> Concrete (BuildTensorKind 7 (TKR2 1 (TKScalar Double))))
-> Concrete (TKR2 1 (TKScalar Double))
-> Concrete (BuildTensorKind 7 (TKR2 1 (TKScalar Double)))
forall a b. (a -> b) -> a -> b
$ SNat 9
-> SingletonTK (TKR 0 Double)
-> Concrete (TKR 0 Double)
-> Concrete (BuildTensorKind 9 (TKR 0 Double))
forall (z :: TK) (k :: Natural).
ConvertTensor Concrete =>
SNat k
-> SingletonTK z -> Concrete z -> Concrete (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Natural).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate (forall (n :: Natural). KnownNat n => SNat n
SNat @9) SingletonTK (TKR 0 Double)
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK (Concrete (TKR 0 Double)
-> Concrete (BuildTensorKind 9 (TKR 0 Double)))
-> Concrete (TKR 0 Double)
-> Concrete (BuildTensorKind 9 (TKR 0 Double))
forall a b. (a -> b) -> a -> b
$ Double -> Concrete (TKR 0 Double)
forall r (target :: Target).
(GoodScalar r, BaseTensor target) =>
r -> target (TKR 0 r)
rscalar Double
7
afcnn2 :: ADReady f
=> MnistCnnRanked2.ADCnnMnistParameters f Double
-> f (TKR 2 Double)
afcnn2 :: forall (f :: Target).
ADReady f =>
ADCnnMnistParameters f Double -> f (TKR 2 Double)
afcnn2 = Int
-> Int
-> Int
-> PrimalOf f (TKR2 4 (TKScalar Double))
-> ADCnnMnistParameters f Double
-> f (TKR 2 Double)
forall (target :: Target) r.
(ADReady target, GoodScalar r, Differentiable r) =>
Int
-> Int
-> Int
-> PrimalOf target (TKR 4 r)
-> ADCnnMnistParameters target r
-> target (TKR 2 r)
MnistCnnRanked2.convMnistTwoR
Int
sizeMnistHeightI Int
sizeMnistWidthI Int
batch_size
(Ranked 4 Double -> PrimalOf f (TKR2 4 (TKScalar Double))
forall r (target :: Target) (n :: Natural).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete (Ranked 4 Double -> PrimalOf f (TKR2 4 (TKScalar Double)))
-> Ranked 4 Double -> PrimalOf f (TKR2 4 (TKScalar Double))
forall a b. (a -> b) -> a -> b
$ Concrete (TKR2 4 (TKScalar Double))
-> RepConcrete (TKR2 4 (TKScalar Double))
forall (y :: TK). Concrete y -> RepConcrete y
unConcrete Concrete (TKR2 4 (TKScalar Double))
blackGlyph)
artifactRev :: AstArtifactRev
(X (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) Double))
(TKR 2 Double)
artifactRev = IncomingCotangentHandling
-> (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) Double
-> AstTensor AstMethodLet FullSpan (TKR 2 Double))
-> FullShapeTK
(X (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) Double))
-> AstArtifactRev
(X (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) Double))
(TKR 2 Double)
forall src (ztgt :: TK) tgt.
(AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
(tgt :: Type) ~ (AstTensor AstMethodLet FullSpan ztgt :: Type)) =>
IncomingCotangentHandling
-> (src -> tgt)
-> FullShapeTK (X src)
-> AstArtifactRev (X src) ztgt
revArtifactAdapt IncomingCotangentHandling
UseIncomingCotangent ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) Double
-> AstTensor AstMethodLet FullSpan (TKR 2 Double)
forall (f :: Target).
ADReady f =>
ADCnnMnistParameters f Double -> f (TKR 2 Double)
afcnn2 FullShapeTK
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
FullShapeTK
(X (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) Double))
ftk
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
(TKR 2 Double)
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPrimalPretty (AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
(TKR 2 Double)
-> AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
(TKR 2 Double)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifact AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
(TKR 2 Double)
artifactRev)
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\u1 -> rfromS (let t189 = sreplicate @1 (ssum @4 (stranspose @[2,0,1] (sreshape @[7,9,4] (stranspose @[1,2,0] (sreplicate @1 (stranspose @[1,2,0] (sreplicate @1 (stranspose @[2,0,3,1] (sgather (stranspose @[2,0,1] (sgather (sconcrete (sreplicate [7,9] 7.0)) (\\[i358, i359] -> [i358 + i359]))) (\\[i185, i186] -> [i185 + i186])))))) * sreplicate @7 (sreplicate @9 (sreplicate @1 (sreplicate @1 (sfromR (tproject1 (tproject1 (tproject1 u1))) !$ [0, 0])))))))) + stranspose @[2,0,1] (sreplicate @7 (sreplicate @9 (sfromR (tproject2 (tproject1 (tproject1 u1)))))) ; t204 = sreshape @[3,4,4] (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i192, i193, i194, i195] -> [ifH (sscalar -0.0 <=. negate (t189 !$ [0, kfromS (sconcrete (sfromListLinear [3,2] [0,1,2,3,4,5]) !$ [i192, i194]), kfromS (sconcrete (sfromListLinear [4,2] [0,1,2,3,4,5,6,7]) !$ [i193, i195])])) 0 1]) * sgather (t189 !$ [0]) (\\[i198, i199, i200, i201] -> [kfromS (sconcrete (sfromListLinear [3,2] [0,1,2,3,4,5]) !$ [i198, i200]), kfromS (sconcrete (sfromListLinear [4,2] [0,1,2,3,4,5,6,7]) !$ [i199, i201])])) ; t213 = sreplicate @1 (ssum @4 (stranspose @[2,0,1] (sreshape @[3,4,4] (stranspose @[2,0,3,1] (sgather (stranspose @[2,3,4,0,1] (sgather (stranspose @[3,5,0,4,1,2] (sgather (stranspose @[3,2,1,6,5,4,0] (sreplicate @2 (stranspose @[5,4,3,0,1,2] (sreplicate @4 (sreplicate @3 (sreplicate @2 (stranspose @[2,1,0] t204))))))) (\\[i346, i350] -> [kfromS (smaxIndex (t204 !$ [i346, i350])), i350, i346]))) (\\[i355, i356] -> [i355, i356, i355 + i356]))) (\\[i209, i210] -> [i209, i209 + i210, i210])) * sreplicate @3 (sreplicate @4 (sfromR (tproject1 (tproject2 (tproject1 u1))) !$ [0, 0])))))) + stranspose @[2,0,1] (sreplicate @3 (sreplicate @4 (sfromR (tproject2 (tproject2 (tproject1 u1)))))) ; m223 = sreshape @[2,4] (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i215, i216, i217] -> [ifH (sscalar -0.0 <=. 
negate (t213 !$ [0, i216, kfromS (sconcrete (sfromListLinear [2,2] [0,1,2,3]) !$ [i215, i217])])) 0 1]) * stranspose @[0,2,1] (sgather (str (sslice (SNat @0) (SNat @2) (t213 !$ [0]))) (\\[i219, i220] -> [kfromS (sconcrete (sfromListLinear [2,2] [0,1,2,3]) !$ [i219, i220])]))) ; m227 = sreplicate @1 (sreplicate @5 (sdot0 (sfromR (tproject1 (tproject1 (tproject2 u1))) !$ [0]) (sgather m223 (\\[i224] -> [i224, kfromS (smaxIndex (m223 !$ [i224]))])))) + str (sreplicate @5 (sfromR (tproject2 (tproject1 (tproject2 u1))))) in str (sreplicate @5 (str (sfromR (tproject1 (tproject2 (tproject2 u1)))) !$ [0])) * sreplicate @10 (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i228] -> [ifH (sscalar -0.0 <=. negate (m227 !$ [0, i228])) 0 1]) * m227 !$ [0]) + str (sreplicate @5 (sfromR (tproject2 (tproject2 (tproject2 u1))))))"
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
(TKR 2 Double)
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPrimalPretty AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
(TKR 2 Double)
artifactRev
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\u1 -> let w187 = stranspose @[1,2,0] (sreplicate @1 (stranspose @[1,2,0] (sreplicate @1 (stranspose @[2,0,3,1] (sgather (stranspose @[2,0,1] (sgather (sconcrete (sreplicate [7,9] 7.0)) (\\[i183, i184] -> [i183 + i184]))) (\\[i185, i186] -> [i185 + i186])))))) ; w188 = sreplicate @7 (sreplicate @9 (sreplicate @1 (sreplicate @1 (sfromR (tproject1 (tproject1 (tproject1 u1))) !$ [0, 0])))) ; t189 = sreplicate @1 (ssum @4 (stranspose @[2,0,1] (sreshape @[7,9,4] (w187 * w188)))) + stranspose @[2,0,1] (sreplicate @7 (sreplicate @9 (sfromR (tproject2 (tproject1 (tproject1 u1)))))) ; m190 = str (sreplicate @2 (sconcrete (sreplicate [3] 2) * siota (SNat @3))) + sreplicate @3 (siota (SNat @2)) ; m191 = str (sreplicate @2 (sconcrete (sreplicate [4] 2) * siota (SNat @4))) + sreplicate @4 (siota (SNat @2)) ; u202 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i192, i193, i194, i195] -> [let x196 = m190 !$ [i192, i194] ; x197 = m191 !$ [i193, i195] in ifH (sscalar -0.0 <=. negate (t189 !$ [0, kfromS x196, kfromS x197])) 0 1]) ; u203 = sgather (t189 !$ [0]) (\\[i198, i199, i200, i201] -> [kfromS (m190 !$ [i198, i200]), kfromS (m191 !$ [i199, i201])]) ; t204 = sreshape @[3,4,4] (u202 * u203) ; u211 = stranspose @[2,0,3,1] (sgather (stranspose @[2,3,4,0,1] (sgather (stranspose @[3,5,0,4,1,2] (sgather (stranspose @[3,2,1,6,5,4,0] (sreplicate @2 (stranspose @[5,4,3,0,1,2] (sreplicate @4 (sreplicate @3 (sreplicate @2 (stranspose @[2,1,0] t204))))))) (\\[i205, i206] -> [kfromS (smaxIndex (t204 !$ [i205, i206])), i206, i205]))) (\\[i207, i208] -> [i207, i208, i207 + i208]))) (\\[i209, i210] -> [i209, i209 + i210, i210])) ; u212 = sreplicate @3 (sreplicate @4 (sfromR (tproject1 (tproject2 (tproject1 u1))) !$ [0, 0])) ; t213 = sreplicate @1 (ssum @4 (stranspose @[2,0,1] (sreshape @[3,4,4] (u211 * u212)))) + stranspose @[2,0,1] (sreplicate @3 (sreplicate @4 (sfromR (tproject2 (tproject2 (tproject1 u1)))))) ; m214 = str (sreplicate @2 (sconcrete (sreplicate [2] 2) * siota (SNat 
@2))) + sreplicate @2 (siota (SNat @2)) ; t221 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i215, i216, i217] -> [let x218 = m214 !$ [i215, i217] in ifH (sscalar -0.0 <=. negate (t213 !$ [0, i216, kfromS x218])) 0 1]) ; t222 = stranspose @[0,2,1] (sgather (str (sslice (SNat @0) (SNat @2) (t213 !$ [0]))) (\\[i219, i220] -> [kfromS (m214 !$ [i219, i220])])) ; m223 = sreshape @[2,4] (t221 * t222) ; v225 = sfromR (tproject1 (tproject1 (tproject2 u1))) !$ [0] ; v226 = sgather m223 (\\[i224] -> [i224, kfromS (smaxIndex (m223 !$ [i224]))]) ; m227 = sreplicate @1 (sreplicate @5 (ssum @2 (v225 * v226))) + str (sreplicate @5 (sfromR (tproject2 (tproject1 (tproject2 u1))))) ; v229 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i228] -> [ifH (sscalar -0.0 <=. negate (m227 !$ [0, i228])) 0 1]) ; v230 = m227 !$ [0] ; m231 = str (sreplicate @5 (str (sfromR (tproject1 (tproject2 (tproject2 u1)))) !$ [0])) ; m232 = sreplicate @10 (v229 * v230) in rfromS (m231 * m232 + str (sreplicate @5 (sfromR (tproject2 (tproject2 (tproject2 u1))))))"
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
(TKR 2 Double)
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPretty AstArtifactRev
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
(TKR 2 Double)
artifactRev
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\dret u1 -> let w187 = stranspose @[1,2,0] (sreplicate @1 (stranspose @[1,2,0] (sreplicate @1 (stranspose @[2,0,3,1] (sgather (stranspose @[2,0,1] (sgather (sconcrete (sreplicate [7,9] 7.0)) (\\[i183, i184] -> [i183 + i184]))) (\\[i185, i186] -> [i185 + i186])))))) ; w188 = sreplicate @7 (sreplicate @9 (sreplicate @1 (sreplicate @1 (sfromR (tproject1 (tproject1 (tproject1 u1))) !$ [0, 0])))) ; t189 = sreplicate @1 (ssum @4 (stranspose @[2,0,1] (sreshape @[7,9,4] (w187 * w188)))) + stranspose @[2,0,1] (sreplicate @7 (sreplicate @9 (sfromR (tproject2 (tproject1 (tproject1 u1)))))) ; m190 = str (sreplicate @2 (sconcrete (sreplicate [3] 2) * siota (SNat @3))) + sreplicate @3 (siota (SNat @2)) ; m191 = str (sreplicate @2 (sconcrete (sreplicate [4] 2) * siota (SNat @4))) + sreplicate @4 (siota (SNat @2)) ; u202 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i192, i193, i194, i195] -> [let x196 = m190 !$ [i192, i194] ; x197 = m191 !$ [i193, i195] in ifH (sscalar -0.0 <=. negate (t189 !$ [0, kfromS x196, kfromS x197])) 0 1]) ; u203 = sgather (t189 !$ [0]) (\\[i198, i199, i200, i201] -> [kfromS (m190 !$ [i198, i200]), kfromS (m191 !$ [i199, i201])]) ; t204 = sreshape @[3,4,4] (u202 * u203) ; u211 = stranspose @[2,0,3,1] (sgather (stranspose @[2,3,4,0,1] (sgather (stranspose @[3,5,0,4,1,2] (sgather (stranspose @[3,2,1,6,5,4,0] (sreplicate @2 (stranspose @[5,4,3,0,1,2] (sreplicate @4 (sreplicate @3 (sreplicate @2 (stranspose @[2,1,0] t204))))))) (\\[i205, i206] -> [kfromS (smaxIndex (t204 !$ [i205, i206])), i206, i205]))) (\\[i207, i208] -> [i207, i208, i207 + i208]))) (\\[i209, i210] -> [i209, i209 + i210, i210])) ; u212 = sreplicate @3 (sreplicate @4 (sfromR (tproject1 (tproject2 (tproject1 u1))) !$ [0, 0])) ; t213 = sreplicate @1 (ssum @4 (stranspose @[2,0,1] (sreshape @[3,4,4] (u211 * u212)))) + stranspose @[2,0,1] (sreplicate @3 (sreplicate @4 (sfromR (tproject2 (tproject2 (tproject1 u1)))))) ; m214 = str (sreplicate @2 (sconcrete (sreplicate [2] 2) * siota 
(SNat @2))) + sreplicate @2 (siota (SNat @2)) ; t221 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i215, i216, i217] -> [let x218 = m214 !$ [i215, i217] in ifH (sscalar -0.0 <=. negate (t213 !$ [0, i216, kfromS x218])) 0 1]) ; t222 = stranspose @[0,2,1] (sgather (str (sslice (SNat @0) (SNat @2) (t213 !$ [0]))) (\\[i219, i220] -> [kfromS (m214 !$ [i219, i220])])) ; m223 = sreshape @[2,4] (t221 * t222) ; v225 = sfromR (tproject1 (tproject1 (tproject2 u1))) !$ [0] ; v226 = sgather m223 (\\[i224] -> [i224, kfromS (smaxIndex (m223 !$ [i224]))]) ; m227 = sreplicate @1 (sreplicate @5 (ssum @2 (v225 * v226))) + str (sreplicate @5 (sfromR (tproject2 (tproject1 (tproject2 u1))))) ; v229 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i228] -> [ifH (sscalar -0.0 <=. negate (m227 !$ [0, i228])) 0 1]) ; v230 = m227 !$ [0] ; m231 = str (sreplicate @5 (str (sfromR (tproject1 (tproject2 (tproject2 u1)))) !$ [0])) ; m232 = sreplicate @10 (v229 * v230) ; m234 = soneHot (v229 * ssum @10 (m231 * sfromR dret)) [0] ; v235 = sreplicate @2 (ssum @5 (ssum @1 m234)) ; t239 = soneHot (sappend (sconcrete (sfromListLinear [0,4] [])) (sappend (str (sscatter (stranspose @[0,2,1] (t221 * sreshape @[2,2,2] (sscatter (v225 * v235) (\\[i236] -> [i236, kfromS (smaxIndex (m223 !$ [i236]))])))) (\\[i237, i238] -> [kfromS (m214 !$ [i237, i238])]))) (sconcrete (sreplicate [1,4] 0.0)))) [0] ; u240 = sreshape @[3,4,2,2] (stranspose @[1,2,0] (sreplicate @4 (ssum @1 t239))) ; t251 = soneHot (sscatter (u202 * sreshape @[3,4,2,2] (stranspose @[2,1,0] (ssum @2 (ssum @3 (ssum @4 (stranspose @[3,4,5,2,1,0] (ssum @2 (stranspose @[6,2,1,0,5,4,3] (sscatter (stranspose @[2,4,5,0,3,1] (sscatter (stranspose @[3,4,0,1,2] (sscatter (stranspose @[1,3,0,2] (u212 * u240)) (\\[i241, i242] -> [i241, i241 + i242, i242]))) (\\[i243, i244] -> [i243, i244, i243 + i244]))) (\\[i245, i246] -> [kfromS (smaxIndex (t204 !$ [i245, i246])), i246, i245])))))))))) (\\[i247, i248, i249, i250] -> [kfromS (m190 !$ 
[i247, i249]), kfromS (m191 !$ [i248, i250])])) [0] in tpair (tpair (tpair (rfromS (soneHot (ssum @1 (ssum @1 (ssum @9 (ssum @7 (w187 * sreshape @[7,9,1,1,2,2] (stranspose @[1,2,0] (sreplicate @4 (ssum @1 t251)))))))) [0, 0])) (rfromS (ssum @9 (ssum @7 (stranspose @[1,2,0] t251))))) (tpair (rfromS (soneHot (ssum @4 (ssum @3 (u211 * u240))) [0, 0])) (rfromS (ssum @4 (ssum @3 (stranspose @[1,2,0] t239)))))) (tpair (tpair (rfromS (soneHot (v226 * v235) [0])) (rfromS (ssum @5 (str m234)))) (tpair (rfromS (str (soneHot (ssum @5 (str (m232 * sfromR dret))) [0]))) (rfromS (ssum @5 (str (sfromR dret))))))"
-- | Check that interpreting the AST of the ranked MNIST CNN gives the same
-- result as running the network directly on concrete parameters.
--
-- The network 'afcnn2' is instantiated twice: once on an AST variable
-- ('afcnn1'), and once on the concrete 'valsInit'. The AST term is then
-- interpreted in an environment that binds the variable to the same
-- parameters ('vals'), and the outcome is compared against the direct run
-- three times: for the raw term, for the term after 'simplifyInline' and
-- for the term after 'simplifyInlineContract'.
testCNNOAst1 :: Assertion
testCNNOAst1 = do
  let batch_size = 5
      sizeMnistWidthI = 7
      sizeMnistHeightI = 9
      -- Full shape of the parameter tuple, read off the concrete values.
      ftk = tftk @Concrete
                 (knownSTK @(X (MnistCnnRanked2.ADCnnMnistParameters
                                  Concrete Double)))
                 vals
      -- A variable name built from a fixed, hard-coded variable id.
      varName = mkAstVarName ftk Nothing . intToAstVarId $ 100000000
      var :: AstTensor AstMethodLet FullSpan
                       (X (MnistCnnRanked2.ADCnnMnistParameters
                             Concrete Double))
      var = AstVar varName
      -- Random shaped parameters from a fixed generator seed, with the
      -- shapes forgotten afterwards.
      -- NOTE(review): the type-level 7 and 9 presumably have to agree with
      -- sizeMnistWidthI and sizeMnistHeightI above -- confirm.
      valsInit :: MnistCnnRanked2.ADCnnMnistParameters Concrete Double
      valsInit =
        forgetShape $ fst
        $ randomValue @(MnistCnnRanked2.ADCnnMnistParametersShaped
                          Concrete 7 9
                          1 1 1 1 Double)
                      0.4 (mkStdGen 44)
      vals = toTarget @Concrete valsInit
      -- Environment in which the AST variable denotes the parameters.
      env = extendEnv varName vals emptyEnv
      -- Constant input: the scalar 7 replicated 9, 7, 1 and 5 times
      -- (innermost replication first).
      blackGlyph = treplicate (SNat @5) knownSTK
                   $ treplicate (SNat @1) knownSTK
                   $ treplicate (SNat @7) knownSTK
                   $ treplicate (SNat @9) knownSTK $ rscalar 7
      afcnn2 :: ADReady f
             => MnistCnnRanked2.ADCnnMnistParameters f Double
             -> f (TKR 2 Double)
      afcnn2 = MnistCnnRanked2.convMnistTwoR
                 sizeMnistHeightI sizeMnistWidthI batch_size
                 (rconcrete $ unConcrete blackGlyph)
      -- The network applied to the symbolic parameters.
      afcnn1 = afcnn2 $ fromTarget var
  interpretAstFull @Concrete env afcnn1
    @?= afcnn2 valsInit
  interpretAstFull @Concrete env
                   (simplifyInline @(TKR 2 Double) afcnn1)
    @?= afcnn2 valsInit
  interpretAstFull @Concrete env
                   (simplifyInlineContract @(TKR 2 Double) afcnn1)
    @?= afcnn2 valsInit
-- | Golden pretty-printing test for the reverse-mode artifact of
-- 'MnistCnnRanked2.convMnistTwoR'.
--
-- The artifact is built once with 'revArtifactAdapt' (passing
-- 'UseIncomingCotangent') and three renderings of it are compared verbatim
-- against the expected strings below:
--
--   1. 'printArtifactPrimalPretty' applied after 'simplifyArtifact',
--   2. 'printArtifactPrimalPretty' on the artifact as built,
--   3. 'printArtifactPretty' on the artifact as built.
--
-- The expected strings are compared with '@?=', so any change to them (even
-- whitespace) changes what the test accepts.
testCNNOPP2 :: Assertion
testCNNOPP2 = do
  -- NOTE(review): presumably this makes the numbered variable names that
  -- appear in the expected strings (u1, t255, ...) reproducible -- confirm
  -- against HordeAd.Core.AstFreshId.
  resetVarCounter
  let batch_size = 7
      sizeMnistWidthI = 14
      sizeMnistHeightI = 23
      -- Full shape of the parameter tuple, read off the concrete values.
      ftk = tftk @Concrete
                 (knownSTK @(X (MnistCnnRanked2.ADCnnMnistParameters
                                  Concrete Double)))
                 vals
      -- Parameters drawn by 'randomValue' (argument 0.4, generator
      -- @mkStdGen 44@) at the shaped parameter type and then passed through
      -- 'forgetShape'.
      -- NOTE(review): the type-level 14 and 23 presumably mirror
      -- sizeMnistWidthI and sizeMnistHeightI above -- keep them in sync.
      valsInit :: MnistCnnRanked2.ADCnnMnistParameters Concrete Double
      valsInit =
        forgetShape $ fst
        $ randomValue @(MnistCnnRanked2.ADCnnMnistParametersShaped
                         Concrete 14 23
                         2 3 4 5 Double)
                      0.4 (mkStdGen 44)
      vals = toTarget @Concrete valsInit
      -- Constant input: the scalar 7 replicated 23, then 14, then 1, then 7
      -- times (innermost replication listed first).
      blackGlyph = treplicate (SNat @7) knownSTK
                   $ treplicate (SNat @1) knownSTK
                   $ treplicate (SNat @14) knownSTK
                   $ treplicate (SNat @23) knownSTK
                   $ rscalar 7
      -- The function under differentiation: 'convMnistTwoR' with the sizes
      -- and the constant glyph fixed, leaving only the parameters free.
      afcnn2 :: ADReady f
             => MnistCnnRanked2.ADCnnMnistParameters f Double
             -> f (TKR 2 Double)
      afcnn2 = MnistCnnRanked2.convMnistTwoR
                 sizeMnistHeightI sizeMnistWidthI batch_size
                 (rconcrete $ unConcrete blackGlyph)
      artifactRev = revArtifactAdapt UseIncomingCotangent afcnn2 ftk
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (let t255 = ssum @12 (stranspose @[3,0,1,2] (sreshape @[4,14,23,12] (sreplicate @4 (stranspose @[1,2,0] (sreplicate @1 (stranspose @[1,2,0] (sreplicate @1 (stranspose @[2,0,3,1] (sgather (stranspose @[2,0,1] (sgather (sconcrete (sreplicate [14,23] 7.0)) (\\[i473, i474] -> [i473 + i474]))) (\\[i251, i252] -> [i251 + i252]))))))) * str (sreplicate @14 (str (sreplicate @23 (str (sreplicate @1 (str (sreplicate @1 (str (sfromR (tproject1 (tproject1 (tproject1 u1)))) !$ [0]))))))))))) + stranspose @[2,0,1] (sreplicate @14 (sreplicate @23 (sfromR (tproject2 (tproject1 (tproject1 u1)))))) ; u271 = sreshape @[4,7,11,4] (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i258, i259, i260, i261, i262] -> [ifH (sscalar -0.0 <=. negate (t255 !$ [i258, kfromS (sconcrete (sfromListLinear [7,2] [0,1,2,3,4,5,6,7,8,9,10,11,12,13]) !$ [i259, i261]), kfromS (sconcrete (sfromListLinear [11,2] [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21]) !$ [i260, i262])])) 0 1]) * stranspose @[4,0,1,2,3] (sgather (stranspose @[1,2,0] t255) (\\[i265, i266, i267, i268] -> [kfromS (sconcrete (sfromListLinear [7,2] [0,1,2,3,4,5,6,7,8,9,10,11,12,13]) !$ [i265, i267]), kfromS (sconcrete (sfromListLinear [11,2] [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21]) !$ [i266, i268])]))) ; t280 = ssum @48 (stranspose @[3,0,1,2] (sreshape @[4,7,11,48] (sreplicate @4 (stranspose @[2,0,3,4,1] (sgather (stranspose @[2,4,5,0,3,1] (sgather (stranspose @[4,6,0,5,2,1,3] (sgather (stranspose @[4,3,2,1,7,6,5,0] (sreplicate @3 (stranspose @[6,5,4,3,0,1,2] (sreplicate @11 (sreplicate @7 (sreplicate @4 (stranspose @[3,2,1,0] u271))))))) (\\[i458, i460, i461] -> [kfromS (smaxIndex (u271 !$ [i461, i458, i460])), i460, i458, i461]))) (\\[i466, i468] -> [i466, i468, i466 + i468]))) (\\[i277, i278] -> [i277, i277 + i278, i278]))) * str (sreplicate @7 (str (sreplicate @11 (sfromR (tproject1 (tproject2 (tproject1 u1)))))))))) + stranspose @[2,0,1] (sreplicate @7 (sreplicate @11 (sfromR (tproject2 (tproject2 (tproject1 u1)))))) ; u296 = sreshape @[4,3,5,4] (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i283, i284, i285, i286, i287] -> [ifH (sscalar -0.0 <=. negate (t280 !$ [i283, kfromS (sconcrete (sfromListLinear [3,2] [0,1,2,3,4,5]) !$ [i284, i286]), kfromS (sconcrete (sfromListLinear [5,2] [0,1,2,3,4,5,6,7,8,9]) !$ [i285, i287])])) 0 1]) * stranspose @[4,0,1,2,3] (sgather (stranspose @[1,2,0] t280) (\\[i290, i291, i292, i293] -> [kfromS (sconcrete (sfromListLinear [3,2] [0,1,2,3,4,5]) !$ [i290, i292]), kfromS (sconcrete (sfromListLinear [5,2] [0,1,2,3,4,5,6,7,8,9]) !$ [i291, i293])]))) ; m301 = str (sreplicate @7 (sdot1In (sfromR (tproject1 (tproject1 (tproject2 u1)))) (sreplicate @5 (sreshape @[60] (sgather u296 (\\[i297, i298, i299] -> [i297, i298, i299, kfromS (smaxIndex (u296 !$ [i297, i298, i299]))])))))) + str (sreplicate @7 (sfromR (tproject2 (tproject1 (tproject2 u1))))) in smatmul2 (sfromR (tproject1 (tproject2 (tproject2 u1)))) (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i302, i303] -> [ifH (sscalar -0.0 <=. negate (m301 !$ [i302, i303])) 0 1]) * m301) + str (sreplicate @7 (sfromR (tproject2 (tproject2 (tproject2 u1))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w253 = sreplicate @4 (stranspose @[1,2,0] (sreplicate @1 (stranspose @[1,2,0] (sreplicate @1 (stranspose @[2,0,3,1] (sgather (stranspose @[2,0,1] (sgather (sconcrete (sreplicate [14,23] 7.0)) (\\[i249, i250] -> [i249 + i250]))) (\\[i251, i252] -> [i251 + i252]))))))) ; w254 = str (sreplicate @14 (str (sreplicate @23 (str (sreplicate @1 (str (sreplicate @1 (str (sfromR (tproject1 (tproject1 (tproject1 u1)))) !$ [0])))))))) ; t255 = ssum @12 (stranspose @[3,0,1,2] (sreshape @[4,14,23,12] (w253 * w254))) + stranspose @[2,0,1] (sreplicate @14 (sreplicate @23 (sfromR (tproject2 (tproject1 (tproject1 u1)))))) ; m256 = str (sreplicate @2 (sconcrete (sreplicate [7] 2) * siota (SNat @7))) + sreplicate @7 (siota (SNat @2)) ; m257 = str (sreplicate @2 (sconcrete (sreplicate [11] 2) * siota (SNat @11))) + sreplicate @11 (siota (SNat @2)) ; w269 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i258, i259, i260, i261, i262] -> [let x263 = m256 !$ [i259, i261] ; x264 = m257 !$ [i260, i262] in ifH (sscalar -0.0 <=. negate (t255 !$ [i258, kfromS x263, kfromS x264])) 0 1]) ; w270 = stranspose @[4,0,1,2,3] (sgather (stranspose @[1,2,0] t255) (\\[i265, i266, i267, i268] -> [kfromS (m256 !$ [i265, i267]), kfromS (m257 !$ [i266, i268])])) ; u271 = sreshape @[4,7,11,4] (w269 * w270) ; w279 = sreplicate @4 (stranspose @[2,0,3,4,1] (sgather (stranspose @[2,4,5,0,3,1] (sgather (stranspose @[4,6,0,5,2,1,3] (sgather (stranspose @[4,3,2,1,7,6,5,0] (sreplicate @3 (stranspose @[6,5,4,3,0,1,2] (sreplicate @11 (sreplicate @7 (sreplicate @4 (stranspose @[3,2,1,0] u271))))))) (\\[i272, i273, i274] -> [kfromS (smaxIndex (u271 !$ [i274, i272, i273])), i273, i272, i274]))) (\\[i275, i276] -> [i275, i276, i275 + i276]))) (\\[i277, i278] -> [i277, i277 + i278, i278]))) ; t280 = ssum @48 (stranspose @[3,0,1,2] (sreshape @[4,7,11,48] (w279 * str (sreplicate @7 (str (sreplicate @11 (sfromR (tproject1 (tproject2 (tproject1 u1)))))))))) + stranspose @[2,0,1] (sreplicate @7 (sreplicate @11 (sfromR (tproject2 (tproject2 (tproject1 u1)))))) ; m281 = str (sreplicate @2 (sconcrete (sreplicate [3] 2) * siota (SNat @3))) + sreplicate @3 (siota (SNat @2)) ; m282 = str (sreplicate @2 (sconcrete (sreplicate [5] 2) * siota (SNat @5))) + sreplicate @5 (siota (SNat @2)) ; w294 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i283, i284, i285, i286, i287] -> [let x288 = m281 !$ [i284, i286] ; x289 = m282 !$ [i285, i287] in ifH (sscalar -0.0 <=. negate (t280 !$ [i283, kfromS x288, kfromS x289])) 0 1]) ; w295 = stranspose @[4,0,1,2,3] (sgather (stranspose @[1,2,0] t280) (\\[i290, i291, i292, i293] -> [kfromS (m281 !$ [i290, i292]), kfromS (m282 !$ [i291, i293])])) ; u296 = sreshape @[4,3,5,4] (w294 * w295) ; m300 = str (sreplicate @5 (sreshape @[60] (sgather u296 (\\[i297, i298, i299] -> [i297, i298, i299, kfromS (smaxIndex (u296 !$ [i297, i298, i299]))])))) ; m301 = str (sreplicate @7 (ssum @60 (str (sfromR (tproject1 (tproject1 (tproject2 u1)))) * m300))) + str (sreplicate @7 (sfromR (tproject2 (tproject1 (tproject2 u1))))) ; m304 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i302, i303] -> [ifH (sscalar -0.0 <=. negate (m301 !$ [i302, i303])) 0 1]) ; t305 = str (sreplicate @10 (m304 * m301)) in rfromS (ssum @5 (stranspose @[2,1,0] (sreplicate @7 (sfromR (tproject1 (tproject2 (tproject2 u1))))) * t305) + str (sreplicate @7 (sfromR (tproject2 (tproject2 (tproject2 u1))))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let w253 = sreplicate @4 (stranspose @[1,2,0] (sreplicate @1 (stranspose @[1,2,0] (sreplicate @1 (stranspose @[2,0,3,1] (sgather (stranspose @[2,0,1] (sgather (sconcrete (sreplicate [14,23] 7.0)) (\\[i249, i250] -> [i249 + i250]))) (\\[i251, i252] -> [i251 + i252]))))))) ; w254 = str (sreplicate @14 (str (sreplicate @23 (str (sreplicate @1 (str (sreplicate @1 (str (sfromR (tproject1 (tproject1 (tproject1 u1)))) !$ [0])))))))) ; t255 = ssum @12 (stranspose @[3,0,1,2] (sreshape @[4,14,23,12] (w253 * w254))) + stranspose @[2,0,1] (sreplicate @14 (sreplicate @23 (sfromR (tproject2 (tproject1 (tproject1 u1)))))) ; m256 = str (sreplicate @2 (sconcrete (sreplicate [7] 2) * siota (SNat @7))) + sreplicate @7 (siota (SNat @2)) ; m257 = str (sreplicate @2 (sconcrete (sreplicate [11] 2) * siota (SNat @11))) + sreplicate @11 (siota (SNat @2)) ; w269 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i258, i259, i260, i261, i262] -> [let x263 = m256 !$ [i259, i261] ; x264 = m257 !$ [i260, i262] in ifH (sscalar -0.0 <=. negate (t255 !$ [i258, kfromS x263, kfromS x264])) 0 1]) ; w270 = stranspose @[4,0,1,2,3] (sgather (stranspose @[1,2,0] t255) (\\[i265, i266, i267, i268] -> [kfromS (m256 !$ [i265, i267]), kfromS (m257 !$ [i266, i268])])) ; u271 = sreshape @[4,7,11,4] (w269 * w270) ; w279 = sreplicate @4 (stranspose @[2,0,3,4,1] (sgather (stranspose @[2,4,5,0,3,1] (sgather (stranspose @[4,6,0,5,2,1,3] (sgather (stranspose @[4,3,2,1,7,6,5,0] (sreplicate @3 (stranspose @[6,5,4,3,0,1,2] (sreplicate @11 (sreplicate @7 (sreplicate @4 (stranspose @[3,2,1,0] u271))))))) (\\[i272, i273, i274] -> [kfromS (smaxIndex (u271 !$ [i274, i272, i273])), i273, i272, i274]))) (\\[i275, i276] -> [i275, i276, i275 + i276]))) (\\[i277, i278] -> [i277, i277 + i278, i278]))) ; t280 = ssum @48 (stranspose @[3,0,1,2] (sreshape @[4,7,11,48] (w279 * str (sreplicate @7 (str (sreplicate @11 (sfromR (tproject1 (tproject2 (tproject1 u1)))))))))) + stranspose @[2,0,1] (sreplicate @7 (sreplicate @11 (sfromR (tproject2 (tproject2 (tproject1 u1)))))) ; m281 = str (sreplicate @2 (sconcrete (sreplicate [3] 2) * siota (SNat @3))) + sreplicate @3 (siota (SNat @2)) ; m282 = str (sreplicate @2 (sconcrete (sreplicate [5] 2) * siota (SNat @5))) + sreplicate @5 (siota (SNat @2)) ; w294 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i283, i284, i285, i286, i287] -> [let x288 = m281 !$ [i284, i286] ; x289 = m282 !$ [i285, i287] in ifH (sscalar -0.0 <=. negate (t280 !$ [i283, kfromS x288, kfromS x289])) 0 1]) ; w295 = stranspose @[4,0,1,2,3] (sgather (stranspose @[1,2,0] t280) (\\[i290, i291, i292, i293] -> [kfromS (m281 !$ [i290, i292]), kfromS (m282 !$ [i291, i293])])) ; u296 = sreshape @[4,3,5,4] (w294 * w295) ; m300 = str (sreplicate @5 (sreshape @[60] (sgather u296 (\\[i297, i298, i299] -> [i297, i298, i299, kfromS (smaxIndex (u296 !$ [i297, i298, i299]))])))) ; m301 = str (sreplicate @7 (ssum @60 (str (sfromR (tproject1 (tproject1 (tproject2 u1)))) * m300))) + str (sreplicate @7 (sfromR (tproject2 (tproject1 (tproject2 u1))))) ; m304 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i302, i303] -> [ifH (sscalar -0.0 <=. negate (m301 !$ [i302, i303])) 0 1]) ; t305 = str (sreplicate @10 (m304 * m301)) ; m307 = m304 * ssum @10 (str (stranspose @[2,1,0] (sreplicate @7 (sfromR (tproject1 (tproject2 (tproject2 u1))))) * sreplicate @5 (sfromR dret))) ; m308 = sreplicate @60 (ssum @7 (str m307)) ; t316 = stranspose @[2,0,1] (sscatter (stranspose @[1,2,3,4,0] (w294 * sreshape @[4,3,5,2,2] (sscatter (sreshape @[4,3,5] (ssum @5 (str (str (sfromR (tproject1 (tproject1 (tproject2 u1)))) * m308)))) (\\[i309, i310, i311] -> [i309, i310, i311, kfromS (smaxIndex (u296 !$ [i309, i310, i311]))])))) (\\[i312, i313, i314, i315] -> [kfromS (m281 !$ [i312, i314]), kfromS (m282 !$ [i313, i315])])) ; w317 = sreshape @[4,7,11,4,3,4] (stranspose @[1,2,3,0] (sreplicate @48 t316)) ; t329 = stranspose @[2,0,1] (sscatter (stranspose @[1,2,3,4,0] (w269 * sreshape @[4,7,11,2,2] (stranspose @[3,2,1,0] (ssum @4 (ssum @7 (ssum @11 (stranspose @[4,5,6,3,2,1,0] (ssum @3 (stranspose @[7,3,2,1,0,6,5,4] (sscatter (stranspose @[2,5,4,6,0,3,1] (sscatter (stranspose @[3,5,0,4,1,2] (sscatter (stranspose @[1,4,0,2,3] (ssum @4 (str (sreplicate @7 (str (sreplicate @11 (sfromR (tproject1 (tproject2 (tproject1 u1))))))) * w317))) (\\[i318, i319] -> [i318, i318 + i319, i319]))) (\\[i320, i321] -> [i320, i321, i320 + i321]))) (\\[i322, i323, i324] -> [kfromS (smaxIndex (u271 !$ [i324, i322, i323])), i323, i322, i324]))))))))))) (\\[i325, i326, i327, i328] -> [kfromS (m256 !$ [i325, i327]), kfromS (m257 !$ [i326, i328])])) in tpair (tpair (tpair (rfromS (str (soneHot (ssum @1 (str (ssum @1 (str (ssum @23 (str (ssum @14 (str (w253 * sreshape @[4,14,23,1,1,3,4] (stranspose @[1,2,3,0] (sreplicate @12 t329))))))))))) [0]))) (rfromS (ssum @23 (ssum @14 (stranspose @[1,2,0] t329))))) (tpair (rfromS (ssum @11 (str (ssum @7 (str (w279 * w317)))))) (rfromS (ssum @11 (ssum @7 (stranspose @[1,2,0] t316)))))) (tpair (tpair (rfromS (str (m300 * m308))) (rfromS (ssum @7 (str m307)))) (tpair (rfromS (ssum @7 (stranspose @[2,1,0] (t305 * sreplicate @5 (sfromR dret))))) (rfromS (ssum @7 (str (sfromR dret))))))"
testCNNOAst2 :: Assertion
testCNNOAst2 :: Assertion
testCNNOAst2 = do
let batch_size :: Int
batch_size = Int
7
sizeMnistWidthI :: Int
sizeMnistWidthI = Int
14
sizeMnistHeightI :: Int
sizeMnistHeightI = Int
23
ftk :: FullShapeTK
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
ftk = forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete
(forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(X (MnistCnnRanked2.ADCnnMnistParameters
Concrete Double)))
Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
vals
varName :: AstVarName
s
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
varName = FullShapeTK
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
-> Maybe (Int64, Int64)
-> AstVarId
-> AstVarName
s
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
forall (s :: AstSpanType) (y :: TK).
FullShapeTK y -> Maybe (Int64, Int64) -> AstVarId -> AstVarName s y
mkAstVarName FullShapeTK
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
ftk Maybe (Int64, Int64)
forall a. Maybe a
Nothing (AstVarId
-> AstVarName
s
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))))
-> (Int -> AstVarId)
-> Int
-> AstVarName
s
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> AstVarId
intToAstVarId (Int
-> AstVarName
s
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double))))))
-> Int
-> AstVarName
s
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
forall a b. (a -> b) -> a -> b
$ Int
100000000
var :: AstTensor AstMethodLet FullSpan
(X (MnistCnnRanked2.ADCnnMnistParameters
Concrete Double))
var :: AstTensor
AstMethodLet FullSpan (X (ADCnnMnistParameters Concrete Double))
var = AstVarName
FullSpan
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
-> AstTensor
AstMethodLet
FullSpan
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
forall (b :: AstSpanType) (c :: TK) (a :: AstMethodOfSharing).
AstVarName b c -> AstTensor a b c
AstVar AstVarName
FullSpan
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
forall {s :: AstSpanType}.
AstVarName
s
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
varName
valsInit :: MnistCnnRanked2.ADCnnMnistParameters Concrete Double
valsInit :: ADCnnMnistParameters Concrete Double
valsInit =
((Concrete
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(Concrete
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(Concrete
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double),
Concrete (TKS ((':) @Natural 5 ('[] @Natural)) Double)),
(Concrete
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
-> NoShape
((Concrete
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(Concrete
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(Concrete
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double),
Concrete (TKS ((':) @Natural 5 ('[] @Natural)) Double)),
(Concrete
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape (((Concrete
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(Concrete
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(Concrete
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double),
Concrete (TKS ((':) @Natural 5 ('[] @Natural)) Double)),
(Concrete
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
-> NoShape
((Concrete
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(Concrete
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(Concrete
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double),
Concrete (TKS ((':) @Natural 5 ('[] @Natural)) Double)),
(Concrete
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
-> ((Concrete
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(Concrete
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(Concrete
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double),
Concrete (TKS ((':) @Natural 5 ('[] @Natural)) Double)),
(Concrete
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
-> NoShape
((Concrete
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(Concrete
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(Concrete
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double),
Concrete (TKS ((':) @Natural 5 ('[] @Natural)) Double)),
(Concrete
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
forall a b. (a -> b) -> a -> b
$ (ADCnnMnistParametersShaped Concrete 14 23 2 3 4 5 Double, StdGen)
-> ADCnnMnistParametersShaped Concrete 14 23 2 3 4 5 Double
forall a b. (a, b) -> a
fst
((ADCnnMnistParametersShaped Concrete 14 23 2 3 4 5 Double, StdGen)
-> ADCnnMnistParametersShaped Concrete 14 23 2 3 4 5 Double)
-> (ADCnnMnistParametersShaped Concrete 14 23 2 3 4 5 Double,
StdGen)
-> ADCnnMnistParametersShaped Concrete 14 23 2 3 4 5 Double
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue @(MnistCnnRanked2.ADCnnMnistParametersShaped
Concrete 14 23
2 3 4 5 Double)
Double
0.4 (Int -> StdGen
mkStdGen Int
44)
vals :: Concrete (X (ADCnnMnistParameters Concrete Double))
vals = forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete ADCnnMnistParameters Concrete Double
valsInit
env :: AstEnv Concrete
env = AstVarName
(ZonkAny @AstSpanType 0)
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
-> AstEnv Concrete
-> AstEnv Concrete
forall (target :: Target) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
(ZonkAny @AstSpanType 0)
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
forall {s :: AstSpanType}.
AstVarName
s
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
varName Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR2 4 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
(TKProduct
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))
(TKProduct (TKR 2 Double) (TKR2 1 (TKScalar Double)))))
vals AstEnv Concrete
forall (target :: Target). AstEnv target
emptyEnv
blackGlyph :: Concrete
(BuildTensorKind
7
(BuildTensorKind
1 (BuildTensorKind 14 (TKR2 1 (TKScalar Double)))))
blackGlyph = SNat 7
-> SingletonTK
(BuildTensorKind 1 (BuildTensorKind 14 (TKR2 1 (TKScalar Double))))
-> Concrete
(BuildTensorKind 1 (BuildTensorKind 14 (TKR2 1 (TKScalar Double))))
-> Concrete
(BuildTensorKind
7
(BuildTensorKind
1 (BuildTensorKind 14 (TKR2 1 (TKScalar Double)))))
forall (z :: TK) (k :: Natural).
ConvertTensor Concrete =>
SNat k
-> SingletonTK z -> Concrete z -> Concrete (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Natural).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate (forall (n :: Natural). KnownNat n => SNat n
SNat @7) SingletonTK
(BuildTensorKind 1 (BuildTensorKind 14 (TKR2 1 (TKScalar Double))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK
(Concrete
(BuildTensorKind 1 (BuildTensorKind 14 (TKR2 1 (TKScalar Double))))
-> Concrete
(BuildTensorKind
7
(BuildTensorKind
1 (BuildTensorKind 14 (TKR2 1 (TKScalar Double))))))
-> Concrete
(BuildTensorKind 1 (BuildTensorKind 14 (TKR2 1 (TKScalar Double))))
-> Concrete
(BuildTensorKind
7
(BuildTensorKind
1 (BuildTensorKind 14 (TKR2 1 (TKScalar Double)))))
forall a b. (a -> b) -> a -> b
$ SNat 1
-> SingletonTK (BuildTensorKind 14 (TKR2 1 (TKScalar Double)))
-> Concrete (BuildTensorKind 14 (TKR2 1 (TKScalar Double)))
-> Concrete
(BuildTensorKind 1 (BuildTensorKind 14 (TKR2 1 (TKScalar Double))))
forall (z :: TK) (k :: Natural).
ConvertTensor Concrete =>
SNat k
-> SingletonTK z -> Concrete z -> Concrete (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Natural).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate (forall (n :: Natural). KnownNat n => SNat n
SNat @1) SingletonTK (BuildTensorKind 14 (TKR2 1 (TKScalar Double)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK
(Concrete (BuildTensorKind 14 (TKR2 1 (TKScalar Double)))
-> Concrete
(BuildTensorKind
1 (BuildTensorKind 14 (TKR2 1 (TKScalar Double)))))
-> Concrete (BuildTensorKind 14 (TKR2 1 (TKScalar Double)))
-> Concrete
(BuildTensorKind 1 (BuildTensorKind 14 (TKR2 1 (TKScalar Double))))
forall a b. (a -> b) -> a -> b
$ SNat 14
-> SingletonTK (TKR2 1 (TKScalar Double))
-> Concrete (TKR2 1 (TKScalar Double))
-> Concrete (BuildTensorKind 14 (TKR2 1 (TKScalar Double)))
forall (z :: TK) (k :: Natural).
ConvertTensor Concrete =>
SNat k
-> SingletonTK z -> Concrete z -> Concrete (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Natural).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate (forall (n :: Natural). KnownNat n => SNat n
SNat @14) SingletonTK (TKR2 1 (TKScalar Double))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK
(Concrete (TKR2 1 (TKScalar Double))
-> Concrete (BuildTensorKind 14 (TKR2 1 (TKScalar Double))))
-> Concrete (TKR2 1 (TKScalar Double))
-> Concrete (BuildTensorKind 14 (TKR2 1 (TKScalar Double)))
forall a b. (a -> b) -> a -> b
$ SNat 23
-> SingletonTK (TKR 0 Double)
-> Concrete (TKR 0 Double)
-> Concrete (BuildTensorKind 23 (TKR 0 Double))
forall (z :: TK) (k :: Natural).
ConvertTensor Concrete =>
SNat k
-> SingletonTK z -> Concrete z -> Concrete (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Natural).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate (forall (n :: Natural). KnownNat n => SNat n
SNat @23) SingletonTK (TKR 0 Double)
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK (Concrete (TKR 0 Double)
-> Concrete (BuildTensorKind 23 (TKR 0 Double)))
-> Concrete (TKR 0 Double)
-> Concrete (BuildTensorKind 23 (TKR 0 Double))
forall a b. (a -> b) -> a -> b
$ Double -> Concrete (TKR 0 Double)
forall r (target :: Target).
(GoodScalar r, BaseTensor target) =>
r -> target (TKR 0 r)
rscalar Double
7
afcnn2 :: ADReady f
=> MnistCnnRanked2.ADCnnMnistParameters f Double
-> f (TKR 2 Double)
afcnn2 :: forall (f :: Target).
ADReady f =>
ADCnnMnistParameters f Double -> f (TKR 2 Double)
afcnn2 = Int
-> Int
-> Int
-> PrimalOf f (TKR2 4 (TKScalar Double))
-> ADCnnMnistParameters f Double
-> f (TKR 2 Double)
forall (target :: Target) r.
(ADReady target, GoodScalar r, Differentiable r) =>
Int
-> Int
-> Int
-> PrimalOf target (TKR 4 r)
-> ADCnnMnistParameters target r
-> target (TKR 2 r)
MnistCnnRanked2.convMnistTwoR
Int
sizeMnistHeightI Int
sizeMnistWidthI Int
batch_size
(Ranked 4 Double -> PrimalOf f (TKR2 4 (TKScalar Double))
forall r (target :: Target) (n :: Natural).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete (Ranked 4 Double -> PrimalOf f (TKR2 4 (TKScalar Double)))
-> Ranked 4 Double -> PrimalOf f (TKR2 4 (TKScalar Double))
forall a b. (a -> b) -> a -> b
$ Concrete (TKR2 4 (TKScalar Double))
-> RepConcrete (TKR2 4 (TKScalar Double))
forall (y :: TK). Concrete y -> RepConcrete y
unConcrete Concrete (TKR2 4 (TKScalar Double))
blackGlyph)
afcnn1 :: AstTensor AstMethodLet FullSpan (TKR 2 Double)
afcnn1 = ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) Double
-> AstTensor AstMethodLet FullSpan (TKR 2 Double)
forall (f :: Target).
ADReady f =>
ADCnnMnistParameters f Double -> f (TKR 2 Double)
afcnn2 (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) Double
-> AstTensor AstMethodLet FullSpan (TKR 2 Double))
-> ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) Double
-> AstTensor AstMethodLet FullSpan (TKR 2 Double)
forall a b. (a -> b) -> a -> b
$ AstTensor
AstMethodLet
FullSpan
(X (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) Double))
-> ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) Double
forall (target :: Target) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget AstTensor
AstMethodLet
FullSpan
(X (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) Double))
AstTensor
AstMethodLet FullSpan (X (ADCnnMnistParameters Concrete Double))
var
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull @Concrete AstEnv Concrete
env AstTensor AstMethodLet FullSpan (TKR 2 Double)
afcnn1
Concrete (TKR 2 Double) -> Concrete (TKR 2 Double) -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= ADCnnMnistParameters Concrete Double -> Concrete (TKR 2 Double)
forall (f :: Target).
ADReady f =>
ADCnnMnistParameters f Double -> f (TKR 2 Double)
afcnn2 ADCnnMnistParameters Concrete Double
valsInit
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull @Concrete AstEnv Concrete
env
(forall (z :: TK) (s :: AstSpanType).
AstSpan s =>
AstTensor AstMethodLet s z -> AstTensor AstMethodLet s z
simplifyInline @(TKR 2 Double) AstTensor AstMethodLet FullSpan (TKR 2 Double)
afcnn1)
Concrete (TKR 2 Double) -> Concrete (TKR 2 Double) -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= ADCnnMnistParameters Concrete Double -> Concrete (TKR 2 Double)
forall (f :: Target).
ADReady f =>
ADCnnMnistParameters f Double -> f (TKR 2 Double)
afcnn2 ADCnnMnistParameters Concrete Double
valsInit
forall (target :: Target) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull @Concrete AstEnv Concrete
env
(forall (z :: TK) (s :: AstSpanType).
AstSpan s =>
AstTensor AstMethodLet s z -> AstTensor AstMethodLet s z
simplifyInlineContract @(TKR 2 Double) AstTensor AstMethodLet FullSpan (TKR 2 Double)
afcnn1)
Concrete (TKR 2 Double) -> Concrete (TKR 2 Double) -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= ADCnnMnistParameters Concrete Double -> Concrete (TKR 2 Double)
forall (f :: Target).
ADReady f =>
ADCnnMnistParameters f Double -> f (TKR 2 Double)
afcnn2 ADCnnMnistParameters Concrete Double
valsInit
testCNNOPP2S :: Assertion
testCNNOPP2S :: Assertion
testCNNOPP2S = do
Assertion
resetVarCounter
let batch_size :: SNat 7
batch_size = forall (n :: Natural). KnownNat n => SNat n
SNat @7
sizeMnistWidthI :: SNat 14
sizeMnistWidthI = forall (n :: Natural). KnownNat n => SNat n
SNat @14
sizeMnistHeightI :: SNat 23
sizeMnistHeightI = forall (n :: Natural). KnownNat n => SNat n
SNat @23
ftk :: FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double))
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double)
(TKS ((':) @Natural 5 ('[] @Natural)) Double))
(TKProduct
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double)
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
ftk = forall (target :: Target) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete
(forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(X (MnistCnnShaped2.ADCnnMnistParametersShaped
Concrete 14 23 2 3 4 5 Double)))
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double))
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double)
(TKS ((':) @Natural 5 ('[] @Natural)) Double))
(TKProduct
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double)
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
vals
valsInit :: MnistCnnShaped2.ADCnnMnistParametersShaped Concrete 14 23 2 3 4 5 Double
valsInit :: ADCnnMnistParametersShaped Concrete 14 23 2 3 4 5 Double
valsInit =
(ADCnnMnistParametersShaped Concrete 14 23 2 3 4 5 Double, StdGen)
-> ADCnnMnistParametersShaped Concrete 14 23 2 3 4 5 Double
forall a b. (a, b) -> a
fst
((ADCnnMnistParametersShaped Concrete 14 23 2 3 4 5 Double, StdGen)
-> ADCnnMnistParametersShaped Concrete 14 23 2 3 4 5 Double)
-> (ADCnnMnistParametersShaped Concrete 14 23 2 3 4 5 Double,
StdGen)
-> ADCnnMnistParametersShaped Concrete 14 23 2 3 4 5 Double
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue @(MnistCnnShaped2.ADCnnMnistParametersShaped
Concrete 14 23
2 3 4 5 Double)
Double
0.4 (Int -> StdGen
mkStdGen Int
44)
vals :: Concrete
(X ((Concrete
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(Concrete
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(Concrete
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double),
Concrete (TKS ((':) @Natural 5 ('[] @Natural)) Double)),
(Concrete
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
vals = forall (target :: Target) vals.
AdaptableTarget target vals =>
vals -> target (X vals)
toTarget @Concrete ((Concrete
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(Concrete
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
Concrete (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(Concrete
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double),
Concrete (TKS ((':) @Natural 5 ('[] @Natural)) Double)),
(Concrete
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double),
Concrete
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
ADCnnMnistParametersShaped Concrete 14 23 2 3 4 5 Double
valsInit
blackGlyph :: Concrete
(BuildTensorKind
7
(BuildTensorKind
1
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double)))))
blackGlyph = SNat 7
-> SingletonTK
(BuildTensorKind
1
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double))))
-> Concrete
(BuildTensorKind
1
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double))))
-> Concrete
(BuildTensorKind
7
(BuildTensorKind
1
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double)))))
forall (z :: TK) (k :: Natural).
ConvertTensor Concrete =>
SNat k
-> SingletonTK z -> Concrete z -> Concrete (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Natural).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate SNat 7
batch_size SingletonTK
(BuildTensorKind
1
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK
(Concrete
(BuildTensorKind
1
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double))))
-> Concrete
(BuildTensorKind
7
(BuildTensorKind
1
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double))))))
-> Concrete
(BuildTensorKind
1
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double))))
-> Concrete
(BuildTensorKind
7
(BuildTensorKind
1
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double)))))
forall a b. (a -> b) -> a -> b
$ SNat 1
-> SingletonTK
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double)))
-> Concrete
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double)))
-> Concrete
(BuildTensorKind
1
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double))))
forall (z :: TK) (k :: Natural).
ConvertTensor Concrete =>
SNat k
-> SingletonTK z -> Concrete z -> Concrete (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Natural).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate (forall (n :: Natural). KnownNat n => SNat n
SNat @1) SingletonTK
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double)))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK
(Concrete
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double)))
-> Concrete
(BuildTensorKind
1
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double)))))
-> Concrete
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double)))
-> Concrete
(BuildTensorKind
1
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double))))
forall a b. (a -> b) -> a -> b
$ SNat 14
-> SingletonTK
(TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double))
-> Concrete
(TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double))
-> Concrete
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double)))
forall (z :: TK) (k :: Natural).
ConvertTensor Concrete =>
SNat k
-> SingletonTK z -> Concrete z -> Concrete (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Natural).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate SNat 14
sizeMnistWidthI SingletonTK
(TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK
(Concrete
(TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double))
-> Concrete
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double))))
-> Concrete
(TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double))
-> Concrete
(BuildTensorKind
14 (TKS2 ((':) @Natural 23 ('[] @Natural)) (TKScalar Double)))
forall a b. (a -> b) -> a -> b
$ SNat 23
-> SingletonTK (TKS ('[] @Natural) Double)
-> Concrete (TKS ('[] @Natural) Double)
-> Concrete (BuildTensorKind 23 (TKS ('[] @Natural) Double))
forall (z :: TK) (k :: Natural).
ConvertTensor Concrete =>
SNat k
-> SingletonTK z -> Concrete z -> Concrete (BuildTensorKind k z)
forall (target :: Target) (z :: TK) (k :: Natural).
(BaseTensor target, ConvertTensor target) =>
SNat k -> SingletonTK z -> target z -> target (BuildTensorKind k z)
treplicate SNat 23
sizeMnistHeightI SingletonTK (TKS ('[] @Natural) Double)
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK (Concrete (TKS ('[] @Natural) Double)
-> Concrete (BuildTensorKind 23 (TKS ('[] @Natural) Double)))
-> Concrete (TKS ('[] @Natural) Double)
-> Concrete (BuildTensorKind 23 (TKS ('[] @Natural) Double))
forall a b. (a -> b) -> a -> b
$ Double -> Concrete (TKS ('[] @Natural) Double)
forall r (target :: Target).
(GoodScalar r, BaseTensor target) =>
r -> target (TKS ('[] @Natural) r)
sscalar Double
7
afcnn2 :: ADReady f
=> MnistCnnShaped2.ADCnnMnistParametersShaped f 14 23 2 3 4 5 Double
-> f (TKS '[SizeMnistLabel, 7] Double)
afcnn2 :: forall (f :: Target).
ADReady f =>
ADCnnMnistParametersShaped f 14 23 2 3 4 5 Double
-> f (TKS
((':) @Natural SizeMnistLabel ((':) @Natural 7 ('[] @Natural)))
Double)
afcnn2 = SNat 2
-> SNat 3
-> SNat 14
-> SNat 23
-> SNat 4
-> SNat 5
-> SNat 7
-> PrimalOf
f
(TKS
((':)
@Natural
7
((':)
@Natural 1 ((':) @Natural 14 ((':) @Natural 23 ('[] @Natural)))))
Double)
-> ((f (TKS2
((':)
@Natural
4
((':)
@Natural
1
((':) @Natural (2 + 1) ((':) @Natural (3 + 1) ('[] @Natural)))))
(TKScalar Double)),
f (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(f (TKS2
((':)
@Natural
4
((':)
@Natural
4
((':) @Natural (2 + 1) ((':) @Natural (3 + 1) ('[] @Natural)))))
(TKScalar Double)),
f (TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(f (TKS2
((':)
@Natural
5
((':) @Natural ((4 * Div 14 4) * Div 23 4) ('[] @Natural)))
(TKScalar Double)),
f (TKS ((':) @Natural 5 ('[] @Natural)) Double)),
(f (TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double),
f (TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
-> f (TKS
((':) @Natural SizeMnistLabel ((':) @Natural 7 ('[] @Natural)))
Double)
forall (kh :: Natural) (kw :: Natural) (h :: Natural)
(w :: Natural) (c_out :: Natural) (n_hidden :: Natural)
(batch_size :: Natural) (target :: Target) r.
((<=) @Natural 1 kh, (<=) @Natural 1 kw, ADReady target,
GoodScalar r, Differentiable r) =>
SNat kh
-> SNat kw
-> SNat h
-> SNat w
-> SNat c_out
-> SNat n_hidden
-> SNat batch_size
-> PrimalOf
target
(TKS
((':)
@Natural
batch_size
((':)
@Natural 1 ((':) @Natural h ((':) @Natural w ('[] @Natural)))))
r)
-> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
-> target
(TKS
((':)
@Natural SizeMnistLabel ((':) @Natural batch_size ('[] @Natural)))
r)
MnistCnnShaped2.convMnistTwoS
(forall (n :: Natural). KnownNat n => SNat n
SNat @2) (forall (n :: Natural). KnownNat n => SNat n
SNat @3) SNat 14
sizeMnistWidthI SNat 23
sizeMnistHeightI
(forall (n :: Natural). KnownNat n => SNat n
SNat @4) (forall (n :: Natural). KnownNat n => SNat n
SNat @5) SNat 7
batch_size
(Shaped
((':)
@Natural
7
((':)
@Natural 1 ((':) @Natural 14 ((':) @Natural 23 ('[] @Natural)))))
Double
-> PrimalOf
f
(TKS
((':)
@Natural
7
((':)
@Natural 1 ((':) @Natural 14 ((':) @Natural 23 ('[] @Natural)))))
Double)
forall r (target :: Target) (sh :: [Natural]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete (Shaped
((':)
@Natural
7
((':)
@Natural 1 ((':) @Natural 14 ((':) @Natural 23 ('[] @Natural)))))
Double
-> PrimalOf
f
(TKS
((':)
@Natural
7
((':)
@Natural 1 ((':) @Natural 14 ((':) @Natural 23 ('[] @Natural)))))
Double))
-> Shaped
((':)
@Natural
7
((':)
@Natural 1 ((':) @Natural 14 ((':) @Natural 23 ('[] @Natural)))))
Double
-> PrimalOf
f
(TKS
((':)
@Natural
7
((':)
@Natural 1 ((':) @Natural 14 ((':) @Natural 23 ('[] @Natural)))))
Double)
forall a b. (a -> b) -> a -> b
$ Concrete
(TKS
((':)
@Natural
7
((':)
@Natural 1 ((':) @Natural 14 ((':) @Natural 23 ('[] @Natural)))))
Double)
-> RepConcrete
(TKS
((':)
@Natural
7
((':)
@Natural 1 ((':) @Natural 14 ((':) @Natural 23 ('[] @Natural)))))
Double)
forall (y :: TK). Concrete y -> RepConcrete y
unConcrete Concrete
(TKS
((':)
@Natural
7
((':)
@Natural 1 ((':) @Natural 14 ((':) @Natural 23 ('[] @Natural)))))
Double)
blackGlyph)
artifactRev :: AstArtifactRev
(X ((AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 5 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 7 ('[] @Natural)))
Double)
artifactRev = IncomingCotangentHandling
-> (((AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 5 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
-> AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 7 ('[] @Natural)))
Double))
-> FullShapeTK
(X ((AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 5 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
-> AstArtifactRev
(X ((AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 5 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 7 ('[] @Natural)))
Double)
forall src (ztgt :: TK) tgt.
(AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
(tgt :: Type) ~ (AstTensor AstMethodLet FullSpan ztgt :: Type)) =>
IncomingCotangentHandling
-> (src -> tgt)
-> FullShapeTK (X src)
-> AstArtifactRev (X src) ztgt
revArtifactAdapt IncomingCotangentHandling
UseIncomingCotangent ((AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 5 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double)))
-> AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 7 ('[] @Natural)))
Double)
ADCnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan) 14 23 2 3 4 5 Double
-> AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 7 ('[] @Natural)))
Double)
forall (f :: Target).
ADReady f =>
ADCnnMnistParametersShaped f 14 23 2 3 4 5 Double
-> f (TKS
((':) @Natural SizeMnistLabel ((':) @Natural 7 ('[] @Natural)))
Double)
afcnn2 FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double))
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double)
(TKS ((':) @Natural 5 ('[] @Natural)) Double))
(TKProduct
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double)
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
FullShapeTK
(X ((AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 4 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural 5 ('[] @Natural)) Double)),
(AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double),
AstTensor
AstMethodLet
FullSpan
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
ftk
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double))
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double)
(TKS ((':) @Natural 5 ('[] @Natural)) Double))
(TKProduct
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double)
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 7 ('[] @Natural)))
Double)
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPrimalPretty (AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double))
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double)
(TKS ((':) @Natural 5 ('[] @Natural)) Double))
(TKProduct
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double)
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 7 ('[] @Natural)))
Double)
-> AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double))
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double)
(TKS ((':) @Natural 5 ('[] @Natural)) Double))
(TKProduct
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double)
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 7 ('[] @Natural)))
Double)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifact AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double))
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double)
(TKS ((':) @Natural 5 ('[] @Natural)) Double))
(TKProduct
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double)
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 7 ('[] @Natural)))
Double)
artifactRev)
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\u1 -> let t255 = ssum @12 (stranspose @[3,0,1,2] (sreshape @[4,14,23,12] (sreplicate @4 (stranspose @[1,2,0] (sreplicate @1 (stranspose @[1,2,0] (sreplicate @1 (stranspose @[2,0,3,1] (sgather (stranspose @[2,0,1] (sgather (sconcrete (sreplicate [14,23] 7.0)) (\\[i473, i474] -> [i473 + i474]))) (\\[i251, i252] -> [i251 + i252]))))))) * str (sreplicate @14 (str (sreplicate @23 (str (sreplicate @1 (str (sreplicate @1 (str (tproject1 (tproject1 (tproject1 u1))) !$ [0]))))))))))) + stranspose @[2,0,1] (sreplicate @14 (sreplicate @23 (tproject2 (tproject1 (tproject1 u1))))) ; u271 = sreshape @[4,7,11,4] (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i258, i259, i260, i261, i262] -> [ifH (sscalar -0.0 <=. negate (t255 !$ [i258, kfromS (sconcrete (sfromListLinear [7,2] [0,1,2,3,4,5,6,7,8,9,10,11,12,13]) !$ [i259, i261]), kfromS (sconcrete (sfromListLinear [11,2] [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21]) !$ [i260, i262])])) 0 1]) * stranspose @[4,0,1,2,3] (sgather (stranspose @[1,2,0] t255) (\\[i265, i266, i267, i268] -> [kfromS (sconcrete (sfromListLinear [7,2] [0,1,2,3,4,5,6,7,8,9,10,11,12,13]) !$ [i265, i267]), kfromS (sconcrete (sfromListLinear [11,2] [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21]) !$ [i266, i268])]))) ; t280 = ssum @48 (stranspose @[3,0,1,2] (sreshape @[4,7,11,48] (sreplicate @4 (stranspose @[2,0,3,4,1] (sgather (stranspose @[2,4,5,0,3,1] (sgather (stranspose @[4,6,0,5,2,1,3] (sgather (stranspose @[4,3,2,1,7,6,5,0] (sreplicate @3 (stranspose @[6,5,4,3,0,1,2] (sreplicate @11 (sreplicate @7 (sreplicate @4 (stranspose @[3,2,1,0] u271))))))) (\\[i458, i460, i461] -> [kfromS (smaxIndex (u271 !$ [i461, i458, i460])), i460, i458, i461]))) (\\[i466, i468] -> [i466, i468, i466 + i468]))) (\\[i277, i278] -> [i277, i277 + i278, i278]))) * str (sreplicate @7 (str (sreplicate @11 (tproject1 (tproject2 (tproject1 u1))))))))) + stranspose @[2,0,1] (sreplicate @7 (sreplicate @11 (tproject2 (tproject2 (tproject1 u1))))) ; u296 = 
sreshape @[4,3,5,4] (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i283, i284, i285, i286, i287] -> [ifH (sscalar -0.0 <=. negate (t280 !$ [i283, kfromS (sconcrete (sfromListLinear [3,2] [0,1,2,3,4,5]) !$ [i284, i286]), kfromS (sconcrete (sfromListLinear [5,2] [0,1,2,3,4,5,6,7,8,9]) !$ [i285, i287])])) 0 1]) * stranspose @[4,0,1,2,3] (sgather (stranspose @[1,2,0] t280) (\\[i290, i291, i292, i293] -> [kfromS (sconcrete (sfromListLinear [3,2] [0,1,2,3,4,5]) !$ [i290, i292]), kfromS (sconcrete (sfromListLinear [5,2] [0,1,2,3,4,5,6,7,8,9]) !$ [i291, i293])]))) ; m301 = str (sreplicate @7 (sdot1In (tproject1 (tproject1 (tproject2 u1))) (sreplicate @5 (sreshape @[60] (sgather u296 (\\[i297, i298, i299] -> [i297, i298, i299, kfromS (smaxIndex (u296 !$ [i297, i298, i299]))])))))) + str (sreplicate @7 (tproject2 (tproject1 (tproject2 u1)))) in smatmul2 (tproject1 (tproject2 (tproject2 u1))) (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i302, i303] -> [ifH (sscalar -0.0 <=. negate (m301 !$ [i302, i303])) 0 1]) * m301) + str (sreplicate @7 (tproject2 (tproject2 (tproject2 u1))))"
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double))
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double)
(TKS ((':) @Natural 5 ('[] @Natural)) Double))
(TKProduct
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double)
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 7 ('[] @Natural)))
Double)
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPrimalPretty AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double))
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double)
(TKS ((':) @Natural 5 ('[] @Natural)) Double))
(TKProduct
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double)
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 7 ('[] @Natural)))
Double)
artifactRev
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\u1 -> let w253 = sreplicate @4 (stranspose @[1,2,0] (sreplicate @1 (stranspose @[1,2,0] (sreplicate @1 (stranspose @[2,0,3,1] (sgather (stranspose @[2,0,1] (sgather (sconcrete (sreplicate [14,23] 7.0)) (\\[i249, i250] -> [i249 + i250]))) (\\[i251, i252] -> [i251 + i252]))))))) ; w254 = str (sreplicate @14 (str (sreplicate @23 (str (sreplicate @1 (str (sreplicate @1 (str (tproject1 (tproject1 (tproject1 u1))) !$ [0])))))))) ; t255 = ssum @12 (stranspose @[3,0,1,2] (sreshape @[4,14,23,12] (w253 * w254))) + stranspose @[2,0,1] (sreplicate @14 (sreplicate @23 (tproject2 (tproject1 (tproject1 u1))))) ; m256 = str (sreplicate @2 (sconcrete (sreplicate [7] 2) * siota (SNat @7))) + sreplicate @7 (siota (SNat @2)) ; m257 = str (sreplicate @2 (sconcrete (sreplicate [11] 2) * siota (SNat @11))) + sreplicate @11 (siota (SNat @2)) ; w269 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i258, i259, i260, i261, i262] -> [let x263 = m256 !$ [i259, i261] ; x264 = m257 !$ [i260, i262] in ifH (sscalar -0.0 <=. 
negate (t255 !$ [i258, kfromS x263, kfromS x264])) 0 1]) ; w270 = stranspose @[4,0,1,2,3] (sgather (stranspose @[1,2,0] t255) (\\[i265, i266, i267, i268] -> [kfromS (m256 !$ [i265, i267]), kfromS (m257 !$ [i266, i268])])) ; u271 = sreshape @[4,7,11,4] (w269 * w270) ; w279 = sreplicate @4 (stranspose @[2,0,3,4,1] (sgather (stranspose @[2,4,5,0,3,1] (sgather (stranspose @[4,6,0,5,2,1,3] (sgather (stranspose @[4,3,2,1,7,6,5,0] (sreplicate @3 (stranspose @[6,5,4,3,0,1,2] (sreplicate @11 (sreplicate @7 (sreplicate @4 (stranspose @[3,2,1,0] u271))))))) (\\[i272, i273, i274] -> [kfromS (smaxIndex (u271 !$ [i274, i272, i273])), i273, i272, i274]))) (\\[i275, i276] -> [i275, i276, i275 + i276]))) (\\[i277, i278] -> [i277, i277 + i278, i278]))) ; t280 = ssum @48 (stranspose @[3,0,1,2] (sreshape @[4,7,11,48] (w279 * str (sreplicate @7 (str (sreplicate @11 (tproject1 (tproject2 (tproject1 u1))))))))) + stranspose @[2,0,1] (sreplicate @7 (sreplicate @11 (tproject2 (tproject2 (tproject1 u1))))) ; m281 = str (sreplicate @2 (sconcrete (sreplicate [3] 2) * siota (SNat @3))) + sreplicate @3 (siota (SNat @2)) ; m282 = str (sreplicate @2 (sconcrete (sreplicate [5] 2) * siota (SNat @5))) + sreplicate @5 (siota (SNat @2)) ; w294 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i283, i284, i285, i286, i287] -> [let x288 = m281 !$ [i284, i286] ; x289 = m282 !$ [i285, i287] in ifH (sscalar -0.0 <=. 
negate (t280 !$ [i283, kfromS x288, kfromS x289])) 0 1]) ; w295 = stranspose @[4,0,1,2,3] (sgather (stranspose @[1,2,0] t280) (\\[i290, i291, i292, i293] -> [kfromS (m281 !$ [i290, i292]), kfromS (m282 !$ [i291, i293])])) ; u296 = sreshape @[4,3,5,4] (w294 * w295) ; m300 = str (sreplicate @5 (sreshape @[60] (sgather u296 (\\[i297, i298, i299] -> [i297, i298, i299, kfromS (smaxIndex (u296 !$ [i297, i298, i299]))])))) ; m301 = str (sreplicate @7 (ssum @60 (str (tproject1 (tproject1 (tproject2 u1))) * m300))) + str (sreplicate @7 (tproject2 (tproject1 (tproject2 u1)))) ; m304 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i302, i303] -> [ifH (sscalar -0.0 <=. negate (m301 !$ [i302, i303])) 0 1]) ; t305 = str (sreplicate @10 (m304 * m301)) in ssum @5 (stranspose @[2,1,0] (sreplicate @7 (tproject1 (tproject2 (tproject2 u1)))) * t305) + str (sreplicate @7 (tproject2 (tproject2 (tproject2 u1))))"
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double))
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double)
(TKS ((':) @Natural 5 ('[] @Natural)) Double))
(TKProduct
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double)
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 7 ('[] @Natural)))
Double)
-> String
forall (x :: TK) (z :: TK). AstArtifactRev x z -> String
printArtifactPretty AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 1 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double))
(TKProduct
(TKS
((':)
@Natural
4
((':)
@Natural 4 ((':) @Natural 3 ((':) @Natural 4 ('[] @Natural)))))
Double)
(TKS ((':) @Natural 4 ('[] @Natural)) Double)))
(TKProduct
(TKProduct
(TKS ((':) @Natural 5 ((':) @Natural 60 ('[] @Natural))) Double)
(TKS ((':) @Natural 5 ('[] @Natural)) Double))
(TKProduct
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 5 ('[] @Natural)))
Double)
(TKS ((':) @Natural SizeMnistLabel ('[] @Natural)) Double))))
(TKS
((':) @Natural SizeMnistLabel ((':) @Natural 7 ('[] @Natural)))
Double)
artifactRev
String -> String -> Assertion
forall a. (Eq a, Show a, HasCallStack) => a -> a -> Assertion
@?= String
"\\dret u1 -> let w253 = sreplicate @4 (stranspose @[1,2,0] (sreplicate @1 (stranspose @[1,2,0] (sreplicate @1 (stranspose @[2,0,3,1] (sgather (stranspose @[2,0,1] (sgather (sconcrete (sreplicate [14,23] 7.0)) (\\[i249, i250] -> [i249 + i250]))) (\\[i251, i252] -> [i251 + i252]))))))) ; w254 = str (sreplicate @14 (str (sreplicate @23 (str (sreplicate @1 (str (sreplicate @1 (str (tproject1 (tproject1 (tproject1 u1))) !$ [0])))))))) ; t255 = ssum @12 (stranspose @[3,0,1,2] (sreshape @[4,14,23,12] (w253 * w254))) + stranspose @[2,0,1] (sreplicate @14 (sreplicate @23 (tproject2 (tproject1 (tproject1 u1))))) ; m256 = str (sreplicate @2 (sconcrete (sreplicate [7] 2) * siota (SNat @7))) + sreplicate @7 (siota (SNat @2)) ; m257 = str (sreplicate @2 (sconcrete (sreplicate [11] 2) * siota (SNat @11))) + sreplicate @11 (siota (SNat @2)) ; w269 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i258, i259, i260, i261, i262] -> [let x263 = m256 !$ [i259, i261] ; x264 = m257 !$ [i260, i262] in ifH (sscalar -0.0 <=. 
negate (t255 !$ [i258, kfromS x263, kfromS x264])) 0 1]) ; w270 = stranspose @[4,0,1,2,3] (sgather (stranspose @[1,2,0] t255) (\\[i265, i266, i267, i268] -> [kfromS (m256 !$ [i265, i267]), kfromS (m257 !$ [i266, i268])])) ; u271 = sreshape @[4,7,11,4] (w269 * w270) ; w279 = sreplicate @4 (stranspose @[2,0,3,4,1] (sgather (stranspose @[2,4,5,0,3,1] (sgather (stranspose @[4,6,0,5,2,1,3] (sgather (stranspose @[4,3,2,1,7,6,5,0] (sreplicate @3 (stranspose @[6,5,4,3,0,1,2] (sreplicate @11 (sreplicate @7 (sreplicate @4 (stranspose @[3,2,1,0] u271))))))) (\\[i272, i273, i274] -> [kfromS (smaxIndex (u271 !$ [i274, i272, i273])), i273, i272, i274]))) (\\[i275, i276] -> [i275, i276, i275 + i276]))) (\\[i277, i278] -> [i277, i277 + i278, i278]))) ; t280 = ssum @48 (stranspose @[3,0,1,2] (sreshape @[4,7,11,48] (w279 * str (sreplicate @7 (str (sreplicate @11 (tproject1 (tproject2 (tproject1 u1))))))))) + stranspose @[2,0,1] (sreplicate @7 (sreplicate @11 (tproject2 (tproject2 (tproject1 u1))))) ; m281 = str (sreplicate @2 (sconcrete (sreplicate [3] 2) * siota (SNat @3))) + sreplicate @3 (siota (SNat @2)) ; m282 = str (sreplicate @2 (sconcrete (sreplicate [5] 2) * siota (SNat @5))) + sreplicate @5 (siota (SNat @2)) ; w294 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i283, i284, i285, i286, i287] -> [let x288 = m281 !$ [i284, i286] ; x289 = m282 !$ [i285, i287] in ifH (sscalar -0.0 <=. 
negate (t280 !$ [i283, kfromS x288, kfromS x289])) 0 1]) ; w295 = stranspose @[4,0,1,2,3] (sgather (stranspose @[1,2,0] t280) (\\[i290, i291, i292, i293] -> [kfromS (m281 !$ [i290, i292]), kfromS (m282 !$ [i291, i293])])) ; u296 = sreshape @[4,3,5,4] (w294 * w295) ; m300 = str (sreplicate @5 (sreshape @[60] (sgather u296 (\\[i297, i298, i299] -> [i297, i298, i299, kfromS (smaxIndex (u296 !$ [i297, i298, i299]))])))) ; m301 = str (sreplicate @7 (ssum @60 (str (tproject1 (tproject1 (tproject2 u1))) * m300))) + str (sreplicate @7 (tproject2 (tproject1 (tproject2 u1)))) ; m304 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i302, i303] -> [ifH (sscalar -0.0 <=. negate (m301 !$ [i302, i303])) 0 1]) ; t305 = str (sreplicate @10 (m304 * m301)) ; m307 = m304 * ssum @10 (str (stranspose @[2,1,0] (sreplicate @7 (tproject1 (tproject2 (tproject2 u1)))) * sreplicate @5 dret)) ; m308 = sreplicate @60 (ssum @7 (str m307)) ; t316 = stranspose @[2,0,1] (sscatter (stranspose @[1,2,3,4,0] (w294 * sreshape @[4,3,5,2,2] (sscatter (sreshape @[4,3,5] (ssum @5 (str (str (tproject1 (tproject1 (tproject2 u1))) * m308)))) (\\[i309, i310, i311] -> [i309, i310, i311, kfromS (smaxIndex (u296 !$ [i309, i310, i311]))])))) (\\[i312, i313, i314, i315] -> [kfromS (m281 !$ [i312, i314]), kfromS (m282 !$ [i313, i315])])) ; w317 = sreshape @[4,7,11,4,3,4] (stranspose @[1,2,3,0] (sreplicate @48 t316)) ; t329 = stranspose @[2,0,1] (sscatter (stranspose @[1,2,3,4,0] (w269 * sreshape @[4,7,11,2,2] (stranspose @[3,2,1,0] (ssum @4 (ssum @7 (ssum @11 (stranspose @[4,5,6,3,2,1,0] (ssum @3 (stranspose @[7,3,2,1,0,6,5,4] (sscatter (stranspose @[2,5,4,6,0,3,1] (sscatter (stranspose @[3,5,0,4,1,2] (sscatter (stranspose @[1,4,0,2,3] (ssum @4 (str (sreplicate @7 (str (sreplicate @11 (tproject1 (tproject2 (tproject1 u1)))))) * w317))) (\\[i318, i319] -> [i318, i318 + i319, i319]))) (\\[i320, i321] -> [i320, i321, i320 + i321]))) (\\[i322, i323, i324] -> [kfromS (smaxIndex (u271 !$ [i324, i322, i323])), 
i323, i322, i324]))))))))))) (\\[i325, i326, i327, i328] -> [kfromS (m256 !$ [i325, i327]), kfromS (m257 !$ [i326, i328])])) in tpair (tpair (tpair (str (soneHot (ssum @1 (str (ssum @1 (str (ssum @23 (str (ssum @14 (str (w253 * sreshape @[4,14,23,1,1,3,4] (stranspose @[1,2,3,0] (sreplicate @12 t329))))))))))) [0])) (ssum @23 (ssum @14 (stranspose @[1,2,0] t329)))) (tpair (ssum @11 (str (ssum @7 (str (w279 * w317))))) (ssum @11 (ssum @7 (stranspose @[1,2,0] t316))))) (tpair (tpair (str (m300 * m308)) (ssum @7 (str m307))) (tpair (ssum @7 (stranspose @[2,1,0] (t305 * sreplicate @5 dret))) (ssum @7 (str dret))))"