module TestMnistRNNS
( testTrees
) where
import Prelude
import Control.Monad (foldM, unless)
import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality ((:~:) (Refl))
import GHC.TypeLits (KnownNat, sameNat)
import System.IO (hPutStrLn, stderr)
import System.Random
import Test.Tasty
import Test.Tasty.HUnit hiding (assert)
import Text.Printf
import Data.Array.Nested.Shaped.Shape
import HordeAd
import HordeAd.Core.Adaptor
import HordeAd.Core.AstEnv
import HordeAd.Core.AstFreshId
import HordeAd.Core.AstInterpret
import EqEpsilon
import MnistData
import MnistRnnShaped2 (ADRnnMnistParametersShaped)
import MnistRnnShaped2 qualified
-- | The tensor-kind ('X') representation of the shaped RNN MNIST
-- parameters at the 'Concrete' target, with image height fixed to
-- 'SizeMnistHeight', output width @out_width@ and scalar type @r@.
type XParams out_width r =
  X (ADRnnMnistParametersShaped Concrete SizeMnistHeight out_width r)
-- | The test groups exported by this module.
--
-- Each element is a 'TestTree' group; the three groups are presumably
-- defined further down this module (not visible here).
testTrees :: [TestTree]
testTrees =
  [ tensorADValMnistTestsRNNSA
  , tensorADValMnistTestsRNNSI
  , tensorADValMnistTestsRNNSO
  ]
mnistTestCaseRNNSA
:: forall width batch_size r.
(Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r)
=> String
-> Int -> Int -> SNat width -> SNat batch_size -> Int -> r
-> TestTree
mnistTestCaseRNNSA :: forall (width :: Nat) (batch_size :: Nat) r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> SNat width
-> SNat batch_size
-> Int
-> r
-> TestTree
mnistTestCaseRNNSA String
prefix Int
epochs Int
maxBatches width :: SNat width
width@SNat width
SNat batch_size :: SNat batch_size
batch_size@SNat batch_size
SNat
Int
totalBatchSize r
expected =
let targetInit :: Concrete (XParams width r)
targetInit =
(Concrete (XParams width r), StdGen) -> Concrete (XParams width r)
forall a b. (a, b) -> a
fst ((Concrete (XParams width r), StdGen)
-> Concrete (XParams width r))
-> (Concrete (XParams width r), StdGen)
-> Concrete (XParams width r)
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue @(Concrete (XParams width r)) Double
0.23 (Int -> StdGen
mkStdGen Int
44)
miniBatchSize :: Int
miniBatchSize = SNat batch_size -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat batch_size
batch_size
name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
, Int -> String
forall a. Show a => a -> String
show (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width), Int -> String
forall a. Show a => a -> String
show Int
miniBatchSize
, Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams width r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
(SingletonTK (XParams width r) -> Int)
-> SingletonTK (XParams width r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams width r)
, Int -> String
forall a. Show a => a -> String
show (SingletonTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
targetInit) ]
ftest :: forall batch_size2. KnownNat batch_size2
=> MnistDataBatchS batch_size2 r -> Concrete (XParams width r)
-> r
ftest :: forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest MnistDataBatchS batch_size2 r
_ Concrete (XParams width r)
_ | Just (:~:) @Nat 0 batch_size2
Refl <- Proxy @Nat 0
-> Proxy @Nat batch_size2 -> Maybe ((:~:) @Nat 0 batch_size2)
forall (a :: Nat) (b :: Nat) (proxy1 :: Nat -> Type)
(proxy2 :: Nat -> Type).
(KnownNat a, KnownNat b) =>
proxy1 a -> proxy2 b -> Maybe ((:~:) @Nat a b)
sameNat (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @0) (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @batch_size2) = r
0
ftest MnistDataBatchS batch_size2 r
mnistData Concrete (XParams width r)
testParams =
SNat width
-> SNat batch_size2
-> MnistDataBatchS batch_size2 r
-> ADRnnMnistParametersShaped Concrete SizeMnistHeight width r
-> r
forall (target :: TK -> Type) (h :: Nat) (w :: Nat)
(out_width :: Nat) (batch_size :: Nat) r.
((h :: Nat) ~ (SizeMnistHeight :: Nat),
(w :: Nat) ~ (SizeMnistHeight :: Nat),
(target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
Differentiable r, GoodScalar r) =>
SNat out_width
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADRnnMnistParametersShaped target h out_width r
-> r
MnistRnnShaped2.rnnMnistTestS
SNat width
width (forall (n :: Nat). KnownNat n => SNat n
SNat @batch_size2) MnistDataBatchS batch_size2 r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete (XParams width r)
testParams)
in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
String
prefix Int
epochs Int
maxBatches
trainData <- (MnistData r -> MnistDataS r) -> [MnistData r] -> [MnistDataS r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataS r
forall r. PrimElt r => MnistData r -> MnistDataS r
mkMnistDataS
([MnistData r] -> [MnistDataS r])
-> IO [MnistData r] -> IO [MnistDataS r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
testData <- map mkMnistDataS . take (totalBatchSize * maxBatches)
<$> loadMnistData testGlyphsPath testLabelsPath
withSNat ((totalBatchSize * maxBatches) `min` 10000)
$ \(SNat @lenTestData) -> do
let testDataS :: MnistDataBatchS n r
testDataS = forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS @lenTestData [MnistDataS r]
testData
f :: MnistDataBatchS batch_size r
-> ADVal Concrete (XParams width r)
-> ADVal Concrete (TKScalar r)
f :: MnistDataBatchS batch_size r
-> ADVal Concrete (XParams width r) -> ADVal Concrete (TKScalar r)
f (Shaped
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r
glyphS, Shaped
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
labelS) ADVal Concrete (XParams width r)
adinputs =
SNat width
-> SNat batch_size
-> (PrimalOf
(ADVal Concrete)
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r),
PrimalOf
(ADVal Concrete)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
-> ADRnnMnistParametersShaped
(ADVal Concrete) SizeMnistHeight width r
-> ADVal Concrete (TKScalar r)
forall (target :: TK -> Type) (h :: Nat) (w :: Nat)
(out_width :: Nat) (batch_size :: Nat) r.
((h :: Nat) ~ (SizeMnistHeight :: Nat),
(w :: Nat) ~ (SizeMnistHeight :: Nat), Differentiable r,
ADReady target, ADReady (PrimalOf target), GoodScalar r) =>
SNat out_width
-> SNat batch_size
-> (PrimalOf
target
(TKS
((':) @Nat batch_size ((':) @Nat h ((':) @Nat w ('[] @Nat)))) r),
PrimalOf
target
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
-> ADRnnMnistParametersShaped target h out_width r
-> target (TKScalar r)
MnistRnnShaped2.rnnMnistLossFusedS
SNat width
width SNat batch_size
batch_size (Shaped
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r
-> Concrete
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
forall r (target :: TK -> Type) (sh :: [Nat]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r
glyphS, Shaped
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
-> Concrete
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
forall r (target :: TK -> Type) (sh :: [Nat]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
labelS)
(forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @(ADVal Concrete) ADVal
Concrete
(X (ADRnnMnistParametersShaped
(ADVal Concrete) SizeMnistHeight width r))
ADVal Concrete (XParams width r)
adinputs)
runBatch :: ( Concrete (XParams width r)
, StateAdam (XParams width r) )
-> (Int, [MnistDataS r])
-> IO ( Concrete (XParams width r)
, StateAdam (XParams width r) )
runBatch :: (Concrete (XParams width r), StateAdam (XParams width r))
-> (Int, [MnistDataS r])
-> IO (Concrete (XParams width r), StateAdam (XParams width r))
runBatch (!Concrete (XParams width r)
parameters, !StateAdam (XParams width r)
stateAdam) (Int
k, [MnistDataS r]
chunk) = do
let chunkS :: [MnistDataBatchS batch_size r]
chunkS = ([MnistDataS r] -> MnistDataBatchS batch_size r)
-> [[MnistDataS r]] -> [MnistDataBatchS batch_size r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataS r] -> MnistDataBatchS batch_size r
forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS
([[MnistDataS r]] -> [MnistDataBatchS batch_size r])
-> [[MnistDataS r]] -> [MnistDataBatchS batch_size r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataS r] -> Bool) -> [[MnistDataS r]] -> [[MnistDataS r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataS r]
ch -> [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSize)
([[MnistDataS r]] -> [[MnistDataS r]])
-> [[MnistDataS r]] -> [[MnistDataS r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataS r]
chunk
res :: (Concrete (XParams width r), StateAdam (XParams width r))
res@(Concrete (XParams width r)
parameters2, StateAdam (XParams width r)
_) =
forall a (x :: TK) (z :: TK).
KnownSTK x =>
(a -> ADVal Concrete x -> ADVal Concrete z)
-> [a] -> Concrete x -> StateAdam x -> (Concrete x, StateAdam x)
sgdAdam @(MnistDataBatchS batch_size r)
@(XParams width r)
MnistDataBatchS batch_size r
-> ADVal Concrete (XParams width r) -> ADVal Concrete (TKScalar r)
f [MnistDataBatchS batch_size r]
chunkS Concrete (XParams width r)
parameters StateAdam (XParams width r)
stateAdam
trainScore :: r
trainScore = Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat ([MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk) ((forall (n :: Nat). KnownNat n => SNat n -> r) -> r)
-> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
forall a b. (a -> b) -> a -> b
$ \(SNat @len) ->
forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @len ([MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
chunk) Concrete (XParams width r)
parameters2
testScore :: r
testScore = forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @lenTestData MnistDataBatchS n r
testDataS Concrete (XParams width r)
parameters2
lenChunk :: Int
lenChunk = [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
String
prefix Int
k Int
lenChunk
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
res
let runEpoch :: Int
-> ( Concrete (XParams width r)
, StateAdam (XParams width r) )
-> IO (Concrete (XParams width r))
runEpoch :: Int
-> (Concrete (XParams width r), StateAdam (XParams width r))
-> IO (Concrete (XParams width r))
runEpoch Int
n (Concrete (XParams width r)
params2, StateAdam (XParams width r)
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
params2
runEpoch Int
n paramsStateAdam :: (Concrete (XParams width r), StateAdam (XParams width r))
paramsStateAdam@(!Concrete (XParams width r)
_, !StateAdam (XParams width r)
_) = do
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
let trainDataShuffled :: [MnistDataS r]
trainDataShuffled = StdGen -> [MnistDataS r] -> [MnistDataS r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataS r]
trainData
chunks :: [(Int, [MnistDataS r])]
chunks = Int -> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
([(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])])
-> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..]
([[MnistDataS r]] -> [(Int, [MnistDataS r])])
-> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataS r]
trainDataShuffled
res <- ((Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> (Int, [MnistDataS r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
-> (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> [(Int, [MnistDataS r])]
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> (Int, [MnistDataS r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
-> (Int, [MnistDataS r])
-> IO (Concrete (XParams width r), StateAdam (XParams width r))
runBatch (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
paramsStateAdam [(Int, [MnistDataS r])]
chunks
runEpoch (succ n) res
ftk :: FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams width r))
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
targetInit
res <- Int
-> (Concrete (XParams width r), StateAdam (XParams width r))
-> IO (Concrete (XParams width r))
runEpoch Int
1 (Concrete (XParams width r)
targetInit, FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
forall (y :: TK). FullShapeTK y -> StateAdam y
initialStateAdam FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
ftk)
let testErrorFinal = r
1 r -> r -> r
forall a. Num a => a -> a -> a
- forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @lenTestData MnistDataBatchS n r
testDataS Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
res
testErrorFinal @?~ expected
-- Specialise 'mnistTestCaseRNNSA' at the scalar type Double; width and
-- batch_size stay polymorphic. NOTE(review): presumably this removes
-- class-dictionary passing for the Double test cases; confirm it is still
-- worthwhile.
{-# SPECIALIZE mnistTestCaseRNNSA
      :: String
      -> Int -> Int -> SNat width -> SNat batch_size -> Int -> Double
      -> TestTree #-}
-- | ADVal (non-symbolic) RNN MNIST tests built with 'mnistTestCaseRNNSA'.
--
-- Arguments of each case, in order:
--
-- * name prefix
-- * number of epochs
-- * maximum batches per epoch
-- * RNN width
-- * mini-batch size
-- * @totalBatchSize@
-- * expected final test error, compared with @\@?~@
--
-- The annotated type of the expected error fixes the scalar type @r@ under
-- test (Double or Float).
tensorADValMnistTestsRNNSA :: TestTree
tensorADValMnistTestsRNNSA = testGroup "RNNS ADVal MNIST tests"
  [ mnistTestCaseRNNSA "RNNSA 1 epoch, 1 batch" 1 1 (SNat @128) (SNat @150) 5000
                       (0.6026 :: Double)
  , mnistTestCaseRNNSA "RNNSA artificial 1 2 3 4 5" 2 3 (SNat @4) (SNat @5) 50
                       (0.8933333 :: Float)
  , mnistTestCaseRNNSA "RNNSA artificial 5 4 3 2 1" 5 4 (SNat @3) (SNat @2) 49
                       (0.8622448979591837 :: Double)
    -- Zero batches per epoch: NOTE(review) presumably nothing is trained or
    -- tested, hence the expected error of 1.0; confirm.
  , mnistTestCaseRNNSA "RNNSA 1 epoch, 0 batch" 1 0 (SNat @128) (SNat @150) 50
                       (1.0 :: Float)
  ]
mnistTestCaseRNNSI
:: forall width batch_size r.
(Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r)
=> String
-> Int -> Int -> SNat width -> SNat batch_size -> Int -> r
-> TestTree
mnistTestCaseRNNSI :: forall (width :: Nat) (batch_size :: Nat) r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> SNat width
-> SNat batch_size
-> Int
-> r
-> TestTree
mnistTestCaseRNNSI String
prefix Int
epochs Int
maxBatches width :: SNat width
width@SNat width
SNat batch_size :: SNat batch_size
batch_size@SNat batch_size
SNat
Int
totalBatchSize r
expected =
let targetInit :: Concrete (XParams width r)
targetInit =
(Concrete (XParams width r), StdGen) -> Concrete (XParams width r)
forall a b. (a, b) -> a
fst ((Concrete (XParams width r), StdGen)
-> Concrete (XParams width r))
-> (Concrete (XParams width r), StdGen)
-> Concrete (XParams width r)
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue @(Concrete (XParams width r)) Double
0.23 (Int -> StdGen
mkStdGen Int
44)
miniBatchSize :: Int
miniBatchSize = SNat batch_size -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat batch_size
batch_size
name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
, Int -> String
forall a. Show a => a -> String
show (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width), Int -> String
forall a. Show a => a -> String
show Int
miniBatchSize
, Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams width r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
(SingletonTK (XParams width r) -> Int)
-> SingletonTK (XParams width r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams width r)
, Int -> String
forall a. Show a => a -> String
show (SingletonTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
targetInit) ]
ftest :: forall batch_size2. KnownNat batch_size2
=> MnistDataBatchS batch_size2 r -> Concrete (XParams width r)
-> r
ftest :: forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest MnistDataBatchS batch_size2 r
_ Concrete (XParams width r)
_ | Just (:~:) @Nat 0 batch_size2
Refl <- Proxy @Nat 0
-> Proxy @Nat batch_size2 -> Maybe ((:~:) @Nat 0 batch_size2)
forall (a :: Nat) (b :: Nat) (proxy1 :: Nat -> Type)
(proxy2 :: Nat -> Type).
(KnownNat a, KnownNat b) =>
proxy1 a -> proxy2 b -> Maybe ((:~:) @Nat a b)
sameNat (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @0) (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @batch_size2) = r
0
ftest MnistDataBatchS batch_size2 r
mnistData Concrete (XParams width r)
testParams =
SNat width
-> SNat batch_size2
-> MnistDataBatchS batch_size2 r
-> ADRnnMnistParametersShaped Concrete SizeMnistHeight width r
-> r
forall (target :: TK -> Type) (h :: Nat) (w :: Nat)
(out_width :: Nat) (batch_size :: Nat) r.
((h :: Nat) ~ (SizeMnistHeight :: Nat),
(w :: Nat) ~ (SizeMnistHeight :: Nat),
(target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
Differentiable r, GoodScalar r) =>
SNat out_width
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADRnnMnistParametersShaped target h out_width r
-> r
MnistRnnShaped2.rnnMnistTestS
SNat width
width (forall (n :: Nat). KnownNat n => SNat n
SNat @batch_size2) MnistDataBatchS batch_size2 r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete (XParams width r)
testParams)
in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
String
prefix Int
epochs Int
maxBatches
trainData <- (MnistData r -> MnistDataS r) -> [MnistData r] -> [MnistDataS r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataS r
forall r. PrimElt r => MnistData r -> MnistDataS r
mkMnistDataS
([MnistData r] -> [MnistDataS r])
-> IO [MnistData r] -> IO [MnistDataS r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
testData <- map mkMnistDataS . take (totalBatchSize * maxBatches)
<$> loadMnistData testGlyphsPath testLabelsPath
withSNat ((totalBatchSize * maxBatches) `min` 10000)
$ \(SNat @lenTestData) -> do
let testDataS :: MnistDataBatchS n r
testDataS = forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS @lenTestData [MnistDataS r]
testData
ftk :: FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams width r)) Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
targetInit
(_, _, var, varAst) <- FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> IO
(AstVarName
PrimalSpan
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
AstTensor
AstMethodShare
PrimalSpan
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
AstVarName
FullSpan
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
AstTensor
AstMethodLet
FullSpan
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall (x :: TK).
FullShapeTK x
-> IO
(AstVarName PrimalSpan x, AstTensor AstMethodShare PrimalSpan x,
AstVarName FullSpan x, AstTensor AstMethodLet FullSpan x)
funToAstRevIO FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
ftk
(varGlyph, astGlyph) <- funToAstIO (FTKS knownShS FTKScalar) id
(varLabel, astLabel) <- funToAstIO (FTKS knownShS FTKScalar) id
let ast :: AstTensor AstMethodLet FullSpan (TKScalar r)
ast = AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (z :: TK) (s :: AstSpanType).
AstSpan s =>
AstTensor AstMethodLet s z -> AstTensor AstMethodLet s z
simplifyInline
(AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall a b. (a -> b) -> a -> b
$ SNat width
-> SNat batch_size
-> (PrimalOf
(AstTensor AstMethodLet FullSpan)
(TKS2
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
(TKScalar r)),
PrimalOf
(AstTensor AstMethodLet FullSpan)
(TKS2
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
(TKScalar r)))
-> ADRnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan) SizeMnistHeight width r
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (target :: TK -> Type) (h :: Nat) (w :: Nat)
(out_width :: Nat) (batch_size :: Nat) r.
((h :: Nat) ~ (SizeMnistHeight :: Nat),
(w :: Nat) ~ (SizeMnistHeight :: Nat), Differentiable r,
ADReady target, ADReady (PrimalOf target), GoodScalar r) =>
SNat out_width
-> SNat batch_size
-> (PrimalOf
target
(TKS
((':) @Nat batch_size ((':) @Nat h ((':) @Nat w ('[] @Nat)))) r),
PrimalOf
target
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
-> ADRnnMnistParametersShaped target h out_width r
-> target (TKScalar r)
MnistRnnShaped2.rnnMnistLossFusedS
SNat width
width SNat batch_size
batch_size (AstTensor
AstMethodLet
PrimalSpan
(TKS2
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
(TKScalar r))
PrimalOf
(AstTensor AstMethodLet FullSpan)
(TKS2
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
(TKScalar r))
astGlyph, AstTensor
AstMethodLet
PrimalSpan
(TKS2
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
(TKScalar r))
PrimalOf
(AstTensor AstMethodLet FullSpan)
(TKS2
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
(TKScalar r))
astLabel)
(AstTensor
AstMethodLet
FullSpan
(X (ADRnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan) SizeMnistHeight width r))
-> ADRnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan) SizeMnistHeight width r
forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget AstTensor
AstMethodLet
FullSpan
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
AstTensor
AstMethodLet
FullSpan
(X (ADRnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan) SizeMnistHeight width r))
varAst)
f :: MnistDataBatchS batch_size r
-> ADVal Concrete (XParams width r)
-> ADVal Concrete (TKScalar r)
f (Shaped
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r
glyph, Shaped
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
label) ADVal Concrete (XParams width r)
varInputs =
let env :: AstEnv (ADVal Concrete)
env = AstVarName
FullSpan
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> ADVal
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
FullSpan
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
var ADVal
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
ADVal Concrete (XParams width r)
varInputs AstEnv (ADVal Concrete)
forall (target :: TK -> Type). AstEnv target
emptyEnv
envMnist :: AstEnv (ADVal Concrete)
envMnist = AstVarName
PrimalSpan
(TKS2
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
(TKScalar r))
-> ADVal
Concrete
(TKS2
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
(TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
PrimalSpan
(TKS2
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
(TKScalar r))
varGlyph (Shaped
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r
-> ADVal
Concrete
(TKS2
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
(TKScalar r))
forall r (target :: TK -> Type) (sh :: [Nat]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r
glyph)
(AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete))
-> AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete)
forall a b. (a -> b) -> a -> b
$ AstVarName
PrimalSpan
(TKS2
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
(TKScalar r))
-> ADVal
Concrete
(TKS2
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
(TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
PrimalSpan
(TKS2
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
(TKScalar r))
varLabel (Shaped
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
-> ADVal
Concrete
(TKS2
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
(TKScalar r))
forall r (target :: TK -> Type) (sh :: [Nat]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
label) AstEnv (ADVal Concrete)
env
in AstEnv (ADVal Concrete)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> ADVal Concrete (TKScalar r)
forall (target :: TK -> Type) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull AstEnv (ADVal Concrete)
envMnist AstTensor AstMethodLet FullSpan (TKScalar r)
ast
runBatch :: ( Concrete (XParams width r)
, StateAdam (XParams width r) )
-> (Int, [MnistDataS r])
-> IO ( Concrete (XParams width r)
, StateAdam (XParams width r) )
runBatch (!Concrete (XParams width r)
parameters, !StateAdam (XParams width r)
stateAdam) (Int
k, [MnistDataS r]
chunk) = do
let chunkS :: [MnistDataBatchS batch_size r]
chunkS = ([MnistDataS r] -> MnistDataBatchS batch_size r)
-> [[MnistDataS r]] -> [MnistDataBatchS batch_size r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataS r] -> MnistDataBatchS batch_size r
forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS
([[MnistDataS r]] -> [MnistDataBatchS batch_size r])
-> [[MnistDataS r]] -> [MnistDataBatchS batch_size r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataS r] -> Bool) -> [[MnistDataS r]] -> [[MnistDataS r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataS r]
ch -> [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSize)
([[MnistDataS r]] -> [[MnistDataS r]])
-> [[MnistDataS r]] -> [[MnistDataS r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataS r]
chunk
res :: (Concrete (XParams width r), StateAdam (XParams width r))
res@(Concrete (XParams width r)
parameters2, StateAdam (XParams width r)
_) =
forall a (x :: TK) (z :: TK).
KnownSTK x =>
(a -> ADVal Concrete x -> ADVal Concrete z)
-> [a] -> Concrete x -> StateAdam x -> (Concrete x, StateAdam x)
sgdAdam @(MnistDataBatchS batch_size r)
@(XParams width r)
MnistDataBatchS batch_size r
-> ADVal Concrete (XParams width r) -> ADVal Concrete (TKScalar r)
f [MnistDataBatchS batch_size r]
chunkS Concrete (XParams width r)
parameters StateAdam (XParams width r)
stateAdam
trainScore :: r
trainScore = Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat ([MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk) ((forall (n :: Nat). KnownNat n => SNat n -> r) -> r)
-> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
forall a b. (a -> b) -> a -> b
$ \(SNat @len) ->
forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @len ([MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
chunk) Concrete (XParams width r)
parameters2
testScore :: r
testScore = forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @lenTestData MnistDataBatchS n r
testDataS Concrete (XParams width r)
parameters2
lenChunk :: Int
lenChunk = [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
String
prefix Int
k Int
lenChunk
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
res
let runEpoch :: Int
-> ( Concrete (XParams width r)
, StateAdam (XParams width r) )
-> IO (Concrete (XParams width r))
runEpoch Int
n (Concrete (XParams width r)
params2, StateAdam (XParams width r)
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
params2
runEpoch Int
n paramsStateAdam :: (Concrete (XParams width r), StateAdam (XParams width r))
paramsStateAdam@(!Concrete (XParams width r)
_, !StateAdam (XParams width r)
_) = do
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
let trainDataShuffled :: [MnistDataS r]
trainDataShuffled = StdGen -> [MnistDataS r] -> [MnistDataS r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataS r]
trainData
chunks :: [(Int, [MnistDataS r])]
chunks = Int -> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
([(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])])
-> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..]
([[MnistDataS r]] -> [(Int, [MnistDataS r])])
-> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataS r]
trainDataShuffled
res <- ((Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> (Int, [MnistDataS r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
-> (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> [(Int, [MnistDataS r])]
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> (Int, [MnistDataS r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
-> (Int, [MnistDataS r])
-> IO (Concrete (XParams width r), StateAdam (XParams width r))
runBatch (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
paramsStateAdam [(Int, [MnistDataS r])]
chunks
runEpoch (succ n) res
res <- runEpoch 1 (targetInit, initialStateAdam ftk)
let testErrorFinal = r
1 r -> r -> r
forall a. Num a => a -> a -> a
- forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @lenTestData MnistDataBatchS n r
testDataS Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
res
testErrorFinal @?~ expected
{-# SPECIALIZE mnistTestCaseRNNSI
:: String
-> Int -> Int -> SNat width -> SNat batch_size -> Int -> Double
-> TestTree #-}
tensorADValMnistTestsRNNSI :: TestTree
tensorADValMnistTestsRNNSI :: TestTree
tensorADValMnistTestsRNNSI = String -> [TestTree] -> TestTree
testGroup String
"RNNS Intermediate MNIST tests"
[ String
-> Int -> Int -> SNat 128 -> SNat 150 -> Int -> Double -> TestTree
forall (width :: Nat) (batch_size :: Nat) r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> SNat width
-> SNat batch_size
-> Int
-> r
-> TestTree
mnistTestCaseRNNSI String
"RNNSI 1 epoch, 1 batch" Int
1 Int
1 (forall (n :: Nat). KnownNat n => SNat n
SNat @128) (forall (n :: Nat). KnownNat n => SNat n
SNat @150) Int
5000
(Double
0.6026 :: Double)
, String
-> Int -> Int -> SNat 4 -> SNat 5 -> Int -> Float -> TestTree
forall (width :: Nat) (batch_size :: Nat) r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> SNat width
-> SNat batch_size
-> Int
-> r
-> TestTree
mnistTestCaseRNNSI String
"RNNSI artificial 1 2 3 4 5" Int
2 Int
3 (forall (n :: Nat). KnownNat n => SNat n
SNat @4) (forall (n :: Nat). KnownNat n => SNat n
SNat @5) Int
50
(Float
0.8933333 :: Float)
, String
-> Int -> Int -> SNat 3 -> SNat 2 -> Int -> Double -> TestTree
forall (width :: Nat) (batch_size :: Nat) r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> SNat width
-> SNat batch_size
-> Int
-> r
-> TestTree
mnistTestCaseRNNSI String
"RNNSI artificial 5 4 3 2 1" Int
5 Int
4 (forall (n :: Nat). KnownNat n => SNat n
SNat @3) (forall (n :: Nat). KnownNat n => SNat n
SNat @2) Int
49
(Double
0.8622448979591837 :: Double)
, String
-> Int -> Int -> SNat 128 -> SNat 150 -> Int -> Float -> TestTree
forall (width :: Nat) (batch_size :: Nat) r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> SNat width
-> SNat batch_size
-> Int
-> r
-> TestTree
mnistTestCaseRNNSI String
"RNNSI 1 epoch, 0 batch" Int
1 Int
0 (forall (n :: Nat). KnownNat n => SNat n
SNat @128) (forall (n :: Nat). KnownNat n => SNat n
SNat @150) Int
50
(Float
1.0 :: Float)
]
mnistTestCaseRNNSO
:: forall width batch_size r.
( Differentiable r, GoodScalar r
, PrintfArg r, AssertEqualUpToEpsilon r, ADTensorScalar r ~ r )
=> String
-> Int -> Int -> SNat width -> SNat batch_size -> Int -> r
-> TestTree
mnistTestCaseRNNSO :: forall (width :: Nat) (batch_size :: Nat) r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r,
(ADTensorScalar r :: Type) ~ (r :: Type)) =>
String
-> Int
-> Int
-> SNat width
-> SNat batch_size
-> Int
-> r
-> TestTree
mnistTestCaseRNNSO String
prefix Int
epochs Int
maxBatches width :: SNat width
width@SNat width
SNat batch_size :: SNat batch_size
batch_size@SNat batch_size
SNat
Int
totalBatchSize r
expected =
let targetInit :: Concrete (XParams width r)
targetInit =
(Concrete (XParams width r), StdGen) -> Concrete (XParams width r)
forall a b. (a, b) -> a
fst ((Concrete (XParams width r), StdGen)
-> Concrete (XParams width r))
-> (Concrete (XParams width r), StdGen)
-> Concrete (XParams width r)
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue @(Concrete (XParams width r)) Double
0.23 (Int -> StdGen
mkStdGen Int
44)
miniBatchSize :: Int
miniBatchSize = SNat batch_size -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat batch_size
batch_size
name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
, Int -> String
forall a. Show a => a -> String
show (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width), Int -> String
forall a. Show a => a -> String
show Int
miniBatchSize
, Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams width r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
(SingletonTK (XParams width r) -> Int)
-> SingletonTK (XParams width r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams width r)
, Int -> String
forall a. Show a => a -> String
show (SingletonTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
targetInit) ]
ftest :: forall batch_size2. KnownNat batch_size2
=> MnistDataBatchS batch_size2 r -> Concrete (XParams width r)
-> r
ftest :: forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest MnistDataBatchS batch_size2 r
_ Concrete (XParams width r)
_ | Just (:~:) @Nat 0 batch_size2
Refl <- Proxy @Nat 0
-> Proxy @Nat batch_size2 -> Maybe ((:~:) @Nat 0 batch_size2)
forall (a :: Nat) (b :: Nat) (proxy1 :: Nat -> Type)
(proxy2 :: Nat -> Type).
(KnownNat a, KnownNat b) =>
proxy1 a -> proxy2 b -> Maybe ((:~:) @Nat a b)
sameNat (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @0) (forall (t :: Nat). Proxy @Nat t
forall {k} (t :: k). Proxy @k t
Proxy @batch_size2) = r
0
ftest MnistDataBatchS batch_size2 r
mnistData Concrete (XParams width r)
testParams =
SNat width
-> SNat batch_size2
-> MnistDataBatchS batch_size2 r
-> ADRnnMnistParametersShaped Concrete SizeMnistHeight width r
-> r
forall (target :: TK -> Type) (h :: Nat) (w :: Nat)
(out_width :: Nat) (batch_size :: Nat) r.
((h :: Nat) ~ (SizeMnistHeight :: Nat),
(w :: Nat) ~ (SizeMnistHeight :: Nat),
(target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
Differentiable r, GoodScalar r) =>
SNat out_width
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADRnnMnistParametersShaped target h out_width r
-> r
MnistRnnShaped2.rnnMnistTestS
SNat width
width (forall (n :: Nat). KnownNat n => SNat n
SNat @batch_size2) MnistDataBatchS batch_size2 r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete (XParams width r)
testParams)
in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
String
prefix Int
epochs Int
maxBatches
trainData <- (MnistData r -> MnistDataS r) -> [MnistData r] -> [MnistDataS r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataS r
forall r. PrimElt r => MnistData r -> MnistDataS r
mkMnistDataS
([MnistData r] -> [MnistDataS r])
-> IO [MnistData r] -> IO [MnistDataS r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
testData <- map mkMnistDataS . take (totalBatchSize * maxBatches)
<$> loadMnistData testGlyphsPath testLabelsPath
withSNat ((totalBatchSize * maxBatches) `min` 10000)
$ \(SNat @lenTestData) -> do
let testDataS :: MnistDataBatchS n r
testDataS = forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS @lenTestData [MnistDataS r]
testData
ftk :: FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams width r)) Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
targetInit
ftkData :: FullShapeTK
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
ftkData = FullShapeTK
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
-> FullShapeTK
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
-> FullShapeTK
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct (ShS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
-> FullShapeTK (TKScalar r)
-> FullShapeTK
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS (SNat batch_size
batch_size
SNat batch_size
-> ShS
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat)))
-> ShS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
forall {sh1 :: [Nat]} (n :: Nat) (sh :: [Nat]).
(KnownNat n, ((':) @Nat n sh :: [Nat]) ~ (sh1 :: [Nat])) =>
SNat n -> ShS sh -> ShS sh1
:$$ SNat SizeMnistHeight
sizeMnistHeight
SNat SizeMnistHeight
-> ShS ((':) @Nat SizeMnistHeight ('[] @Nat))
-> ShS
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat)))
forall {sh1 :: [Nat]} (n :: Nat) (sh :: [Nat]).
(KnownNat n, ((':) @Nat n sh :: [Nat]) ~ (sh1 :: [Nat])) =>
SNat n -> ShS sh -> ShS sh1
:$$ SNat SizeMnistHeight
sizeMnistWidth
SNat SizeMnistHeight
-> ShS ('[] @Nat) -> ShS ((':) @Nat SizeMnistHeight ('[] @Nat))
forall {sh1 :: [Nat]} (n :: Nat) (sh :: [Nat]).
(KnownNat n, ((':) @Nat n sh :: [Nat]) ~ (sh1 :: [Nat])) =>
SNat n -> ShS sh -> ShS sh1
:$$ ShS ('[] @Nat)
forall (sh :: [Nat]).
((sh :: [Nat]) ~ ('[] @Nat :: [Nat])) =>
ShS sh
ZSS) FullShapeTK (TKScalar r)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar)
(ShS ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
-> FullShapeTK (TKScalar r)
-> FullShapeTK
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
forall (sh :: [Nat]) (x :: TK).
ShS sh -> FullShapeTK x -> FullShapeTK (TKS2 sh x)
FTKS (SNat batch_size
batch_size
SNat batch_size
-> ShS ((':) @Nat SizeMnistLabel ('[] @Nat))
-> ShS ((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
forall {sh1 :: [Nat]} (n :: Nat) (sh :: [Nat]).
(KnownNat n, ((':) @Nat n sh :: [Nat]) ~ (sh1 :: [Nat])) =>
SNat n -> ShS sh -> ShS sh1
:$$ SNat SizeMnistLabel
sizeMnistLabel
SNat SizeMnistLabel
-> ShS ('[] @Nat) -> ShS ((':) @Nat SizeMnistLabel ('[] @Nat))
forall {sh1 :: [Nat]} (n :: Nat) (sh :: [Nat]).
(KnownNat n, ((':) @Nat n sh :: [Nat]) ~ (sh1 :: [Nat])) =>
SNat n -> ShS sh -> ShS sh1
:$$ ShS ('[] @Nat)
forall (sh :: [Nat]).
((sh :: [Nat]) ~ ('[] @Nat :: [Nat])) =>
ShS sh
ZSS) FullShapeTK (TKScalar r)
forall r. GoodScalar r => FullShapeTK (TKScalar r)
FTKScalar)
f :: ( ADRnnMnistParametersShaped (AstTensor AstMethodLet FullSpan)
SizeMnistHeight width r
, ( AstTensor AstMethodLet FullSpan
(TKS '[batch_size, SizeMnistHeight, SizeMnistWidth] r)
, AstTensor AstMethodLet FullSpan
(TKS '[batch_size, SizeMnistLabel] r) ) )
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f :: (ADRnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan) SizeMnistHeight width r,
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r),
AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f = \ (ADRnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan) SizeMnistHeight width r
pars, (AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
glyphS, AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
labelS)) ->
SNat width
-> SNat batch_size
-> (PrimalOf
(AstTensor AstMethodLet FullSpan)
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r),
PrimalOf
(AstTensor AstMethodLet FullSpan)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
-> ADRnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan) SizeMnistHeight width r
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (target :: TK -> Type) (h :: Nat) (w :: Nat)
(out_width :: Nat) (batch_size :: Nat) r.
((h :: Nat) ~ (SizeMnistHeight :: Nat),
(w :: Nat) ~ (SizeMnistHeight :: Nat), Differentiable r,
ADReady target, ADReady (PrimalOf target), GoodScalar r) =>
SNat out_width
-> SNat batch_size
-> (PrimalOf
target
(TKS
((':) @Nat batch_size ((':) @Nat h ((':) @Nat w ('[] @Nat)))) r),
PrimalOf
target
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
-> ADRnnMnistParametersShaped target h out_width r
-> target (TKScalar r)
MnistRnnShaped2.rnnMnistLossFusedS
SNat width
width SNat batch_size
batch_size (AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
-> PrimalOf
(AstTensor AstMethodLet FullSpan)
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
forall (target :: TK -> Type) (sh :: [Nat]) (x :: TK).
BaseTensor target =>
target (TKS2 sh x) -> PrimalOf target (TKS2 sh x)
sprimalPart AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
glyphS, AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
-> PrimalOf
(AstTensor AstMethodLet FullSpan)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
forall (target :: TK -> Type) (sh :: [Nat]) (x :: TK).
BaseTensor target =>
target (TKS2 sh x) -> PrimalOf target (TKS2 sh x)
sprimalPart AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
labelS) ADRnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan) SizeMnistHeight width r
pars
artRaw :: AstArtifactRev
(X (ADRnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan) SizeMnistHeight width r,
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r),
AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))))
(TKScalar r)
artRaw = IncomingCotangentHandling
-> ((ADRnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan) SizeMnistHeight width r,
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r),
AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
-> AstTensor AstMethodLet FullSpan (TKScalar r))
-> FullShapeTK
(X (ADRnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan) SizeMnistHeight width r,
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r),
AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))))
-> AstArtifactRev
(X (ADRnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan) SizeMnistHeight width r,
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r),
AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))))
(TKScalar r)
forall src (ztgt :: TK) tgt.
(AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
(tgt :: Type) ~ (AstTensor AstMethodLet FullSpan ztgt :: Type)) =>
IncomingCotangentHandling
-> (src -> tgt)
-> FullShapeTK (X src)
-> AstArtifactRev (X src) ztgt
revArtifactAdapt IncomingCotangentHandling
IgnoreIncomingCotangent
(ADRnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan) SizeMnistHeight width r,
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r),
AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f (FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> FullShapeTK
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
-> FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
forall (y1 :: TK) (z :: TK).
FullShapeTK y1 -> FullShapeTK z -> FullShapeTK (TKProduct y1 z)
FTKProduct FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
ftk FullShapeTK
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
ftkData)
art :: AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
(TKScalar r)
art = AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
(TKScalar r)
-> AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
(TKScalar r)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifactGradient AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
(TKScalar r)
AstArtifactRev
(X (ADRnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan) SizeMnistHeight width r,
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r),
AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))))
(TKScalar r)
artRaw
go :: [MnistDataBatchS batch_size r]
-> ( Concrete (XParams width r)
, StateAdam (XParams width r) )
-> ( Concrete (XParams width r)
, StateAdam (XParams width r) )
go :: [MnistDataBatchS batch_size r]
-> (Concrete (XParams width r), StateAdam (XParams width r))
-> (Concrete (XParams width r), StateAdam (XParams width r))
go [] (Concrete (XParams width r)
parameters, StateAdam (XParams width r)
stateAdam) = (Concrete (XParams width r)
parameters, StateAdam (XParams width r)
stateAdam)
go ((Shaped
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r
glyph, Shaped
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
label) : [MnistDataBatchS batch_size r]
rest) (!Concrete (XParams width r)
parameters, !StateAdam (XParams width r)
stateAdam) =
let parametersAndInput :: Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
parametersAndInput =
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> Concrete
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
parameters (Concrete
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
-> Concrete
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
-> Concrete
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (Shaped
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r
-> Concrete
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
forall r (target :: TK -> Type) (sh :: [Nat]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r
glyph) (Shaped
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
-> Concrete
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)
forall r (target :: TK -> Type) (sh :: [Nat]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
label))
gradient :: Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
gradient = Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
forall (x :: TK) (z :: TK). Concrete (TKProduct x z) -> Concrete x
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
forall a b. (a -> b) -> a -> b
$ (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))),
Concrete (TKScalar r))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
forall a b. (a, b) -> a
fst
((Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))),
Concrete (TKScalar r))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))))
-> (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r))),
Concrete (TKScalar r))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
forall a b. (a -> b) -> a -> b
$ AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
(TKScalar r)
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
-> Maybe (Concrete (ADTensorKind (TKScalar r)))
-> (Concrete
(ADTensorKind
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))),
Concrete (TKScalar r))
forall (x :: TK) (z :: TK).
AstArtifactRev x z
-> Concrete x
-> Maybe (Concrete (ADTensorKind z))
-> (Concrete (ADTensorKind x), Concrete z)
revInterpretArtifact
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
(TKScalar r)
art Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r)
(TKS
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r)))
parametersAndInput Maybe (Concrete (ADTensorKind (TKScalar r)))
Maybe (Concrete (TKScalar r))
forall a. Maybe a
Nothing
in [MnistDataBatchS batch_size r]
-> (Concrete (XParams width r), StateAdam (XParams width r))
-> (Concrete (XParams width r), StateAdam (XParams width r))
go [MnistDataBatchS batch_size r]
rest (forall (y :: TK).
ArgsAdam
-> StateAdam y
-> SingletonTK y
-> Concrete y
-> Concrete (ADTensorKind y)
-> (Concrete y, StateAdam y)
updateWithGradientAdam
@(XParams width r)
ArgsAdam
defaultArgsAdam StateAdam (XParams width r)
stateAdam SingletonTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
SingletonTK (XParams width r)
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete (XParams width r)
parameters
Concrete (ADTensorKind (XParams width r))
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
gradient)
runBatch :: ( Concrete (XParams width r)
, StateAdam (XParams width r) )
-> (Int, [MnistDataS r])
-> IO ( Concrete (XParams width r)
, StateAdam (XParams width r) )
runBatch :: (Concrete (XParams width r), StateAdam (XParams width r))
-> (Int, [MnistDataS r])
-> IO (Concrete (XParams width r), StateAdam (XParams width r))
runBatch (!Concrete (XParams width r)
parameters, !StateAdam (XParams width r)
stateAdam) (Int
k, [MnistDataS r]
chunk) = do
let chunkS :: [MnistDataBatchS batch_size r]
chunkS = ([MnistDataS r] -> MnistDataBatchS batch_size r)
-> [[MnistDataS r]] -> [MnistDataBatchS batch_size r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataS r] -> MnistDataBatchS batch_size r
forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS
([[MnistDataS r]] -> [MnistDataBatchS batch_size r])
-> [[MnistDataS r]] -> [MnistDataBatchS batch_size r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataS r] -> Bool) -> [[MnistDataS r]] -> [[MnistDataS r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataS r]
ch -> [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSize)
([[MnistDataS r]] -> [[MnistDataS r]])
-> [[MnistDataS r]] -> [[MnistDataS r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataS r]
chunk
res :: (Concrete (XParams width r), StateAdam (XParams width r))
res@(Concrete (XParams width r)
parameters2, StateAdam (XParams width r)
_) = [MnistDataBatchS batch_size r]
-> (Concrete (XParams width r), StateAdam (XParams width r))
-> (Concrete (XParams width r), StateAdam (XParams width r))
go [MnistDataBatchS batch_size r]
chunkS (Concrete (XParams width r)
parameters, StateAdam (XParams width r)
stateAdam)
trainScore :: r
trainScore = Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat ([MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk) ((forall (n :: Nat). KnownNat n => SNat n -> r) -> r)
-> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
forall a b. (a -> b) -> a -> b
$ \(SNat @len) ->
forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @len ([MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Nat) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
chunk) Concrete (XParams width r)
parameters2
testScore :: r
testScore = forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @lenTestData MnistDataBatchS n r
testDataS Concrete (XParams width r)
parameters2
lenChunk :: Int
lenChunk = [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
String
prefix Int
k Int
lenChunk
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
res
let runEpoch :: Int
-> ( Concrete (XParams width r)
, StateAdam (XParams width r) )
-> IO (Concrete (XParams width r))
runEpoch :: Int
-> (Concrete (XParams width r), StateAdam (XParams width r))
-> IO (Concrete (XParams width r))
runEpoch Int
n (Concrete (XParams width r)
params2, StateAdam (XParams width r)
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
params2
runEpoch Int
n paramsStateAdam :: (Concrete (XParams width r), StateAdam (XParams width r))
paramsStateAdam@(!Concrete (XParams width r)
_, !StateAdam (XParams width r)
_) = do
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (SNat width -> Int
forall (n :: Nat). SNat n -> Int
sNatValue SNat width
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
let trainDataShuffled :: [MnistDataS r]
trainDataShuffled = StdGen -> [MnistDataS r] -> [MnistDataS r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataS r]
trainData
chunks :: [(Int, [MnistDataS r])]
chunks = Int -> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
([(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])])
-> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..]
([[MnistDataS r]] -> [(Int, [MnistDataS r])])
-> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataS r]
trainDataShuffled
res <- ((Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> (Int, [MnistDataS r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
-> (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> [(Int, [MnistDataS r])]
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> (Int, [MnistDataS r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
-> (Int, [MnistDataS r])
-> IO (Concrete (XParams width r), StateAdam (XParams width r))
runBatch (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
(Concrete (XParams width r), StateAdam (XParams width r))
paramsStateAdam [(Int, [MnistDataS r])]
chunks
runEpoch (succ n) res
res <- Int
-> (Concrete (XParams width r), StateAdam (XParams width r))
-> IO (Concrete (XParams width r))
runEpoch Int
1 (Concrete (XParams width r)
targetInit, FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
forall (y :: TK). FullShapeTK y -> StateAdam y
initialStateAdam FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
ftk)
let testErrorFinal = r
1 r -> r -> r
forall a. Num a => a -> a -> a
- forall (batch_size2 :: Nat).
KnownNat batch_size2 =>
MnistDataBatchS batch_size2 r -> Concrete (XParams width r) -> r
ftest @lenTestData MnistDataBatchS n r
testDataS Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat width ((':) @Nat SizeMnistHeight ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat width ((':) @Nat width ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat width ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat width ('[] @Nat)))
(TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
Concrete (XParams width r)
res
assertEqualUpToEpsilon 1e-1 expected testErrorFinal
{-# SPECIALIZE mnistTestCaseRNNSO
:: String
-> Int -> Int -> SNat width -> SNat batch_size -> Int -> Double
-> TestTree #-}
tensorADValMnistTestsRNNSO :: TestTree
tensorADValMnistTestsRNNSO :: TestTree
tensorADValMnistTestsRNNSO = String -> [TestTree] -> TestTree
testGroup String
"RNNS Once MNIST tests"
[ String
-> Int -> Int -> SNat 128 -> SNat 150 -> Int -> Double -> TestTree
forall (width :: Nat) (batch_size :: Nat) r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r,
(ADTensorScalar r :: Type) ~ (r :: Type)) =>
String
-> Int
-> Int
-> SNat width
-> SNat batch_size
-> Int
-> r
-> TestTree
mnistTestCaseRNNSO String
"RNNSO 1 epoch, 1 batch" Int
1 Int
1 (forall (n :: Nat). KnownNat n => SNat n
SNat @128) (forall (n :: Nat). KnownNat n => SNat n
SNat @150) Int
5000
(Double
0.6026 :: Double)
, String
-> Int -> Int -> SNat 4 -> SNat 5 -> Int -> Float -> TestTree
forall (width :: Nat) (batch_size :: Nat) r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r,
(ADTensorScalar r :: Type) ~ (r :: Type)) =>
String
-> Int
-> Int
-> SNat width
-> SNat batch_size
-> Int
-> r
-> TestTree
mnistTestCaseRNNSO String
"RNNSO artificial 1 2 3 4 5" Int
2 Int
3 (forall (n :: Nat). KnownNat n => SNat n
SNat @4) (forall (n :: Nat). KnownNat n => SNat n
SNat @5) Int
50
(Float
0.8933333 :: Float)
, String
-> Int -> Int -> SNat 3 -> SNat 2 -> Int -> Double -> TestTree
forall (width :: Nat) (batch_size :: Nat) r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r,
(ADTensorScalar r :: Type) ~ (r :: Type)) =>
String
-> Int
-> Int
-> SNat width
-> SNat batch_size
-> Int
-> r
-> TestTree
mnistTestCaseRNNSO String
"RNNSO artificial 5 4 3 2 1" Int
5 Int
4 (forall (n :: Nat). KnownNat n => SNat n
SNat @3) (forall (n :: Nat). KnownNat n => SNat n
SNat @2) Int
49
(Double
0.9336734693877551 :: Double)
, String
-> Int -> Int -> SNat 128 -> SNat 150 -> Int -> Float -> TestTree
forall (width :: Nat) (batch_size :: Nat) r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r,
(ADTensorScalar r :: Type) ~ (r :: Type)) =>
String
-> Int
-> Int
-> SNat width
-> SNat batch_size
-> Int
-> r
-> TestTree
mnistTestCaseRNNSO String
"RNNSO 1 epoch, 0 batch" Int
1 Int
0 (forall (n :: Nat). KnownNat n => SNat n
SNat @128) (forall (n :: Nat). KnownNat n => SNat n
SNat @150) Int
50
(Float
1.0 :: Float)
]