{-# LANGUAGE OverloadedLists #-}
module TestMnistRNNR
( testTrees
) where
import Prelude
import Control.Monad (foldM, unless)
import System.IO (hPutStrLn, stderr)
import System.Random
import Test.Tasty
import Test.Tasty.HUnit hiding (assert)
import Text.Printf
import Data.Array.Nested.Ranked.Shape
import HordeAd
import HordeAd.Core.Adaptor
import HordeAd.Core.AstEnv
import HordeAd.Core.AstFreshId
import HordeAd.Core.AstInterpret
import EqEpsilon
import MnistData
import MnistRnnRanked2 (ADRnnMnistParameters, ADRnnMnistParametersShaped)
import MnistRnnRanked2 qualified
-- | All test groups exported by this module.
testTrees :: [TestTree]
testTrees = [ tensorADValMnistTestsRNNRA
            , tensorADValMnistTestsRNNRI
            , tensorADValMnistTestsRNNRO
            ]
-- | Build a single MNIST RNN test case that trains the ranked RNN from
-- "MnistRnnRanked2" with 'sgdAdam' over @ADVal Concrete@ and then compares
-- the final test error (@1 - ftest ...@) with the expected value via '@?~'.
--
-- Arguments, in order:
--
-- * @prefix@: used in the test name and in every line logged to 'stderr'
-- * @epochs@: training stops once the epoch counter exceeds this value
-- * @maxBatches@: at most this many chunks are processed per epoch; also
--   bounds the amount of test data that is kept
-- * @width@: turned into a type-level natural with 'withSNat' and used
--   as the width of the initial parameters; progress logging is skipped
--   when it is below 10
-- * @miniBatchSize@: size of the mini-batches handed to 'sgdAdam'
-- * @totalBatchSize@: size of the chunks the shuffled training data is
--   split into (one chunk per 'runBatch' call)
-- * @expected@: the expected final test error
mnistTestCaseRNNRA
  :: forall r.
     (Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r)
  => String
  -> Int -> Int -> Int -> Int -> Int -> r
  -> TestTree
mnistTestCaseRNNRA prefix epochs maxBatches width miniBatchSize totalBatchSize
                   expected =
  withSNat width $ \(SNat @width) ->
  let -- Initial parameters: generated at shaped types from a fixed seed
      -- and then converted with 'forgetShape'.
      targetInit =
        forgetShape $ fst
        $ randomValue @(Concrete (X (ADRnnMnistParametersShaped
                                       Concrete width r)))
                      0.23 (mkStdGen 44)
      -- The test name records the configuration and the parameter sizes.
      name = prefix ++ ": "
             ++ unwords [ show epochs, show maxBatches
                        , show width, show miniBatchSize
                        , show $ widthSTK
                          $ knownSTK @(X (ADRnnMnistParameters Concrete r))
                        , show (tsize knownSTK targetInit) ]
      -- Score of the given parameters on a batch of the given size;
      -- the reported error is @1 -@ this score.
      ftest :: Int -> MnistDataBatchR r
            -> Concrete (X (ADRnnMnistParameters Concrete r))
            -> r
      ftest batch_size mnistData pars =
        MnistRnnRanked2.rnnMnistTestR
          batch_size mnistData (fromTarget @Concrete pars)
  in testCase name $ do
       hPutStrLn stderr $
         printf "\n%s: Epochs to run/max batches per epoch: %d/%d"
                prefix epochs maxBatches
       trainData <- map mkMnistDataR
                    <$> loadMnistData trainGlyphsPath trainLabelsPath
       testData <- map mkMnistDataR . take (totalBatchSize * maxBatches)
                   <$> loadMnistData testGlyphsPath testLabelsPath
       let testDataR = mkMnistDataBatchR testData
           -- Objective function differentiated by 'sgdAdam'.
           f :: MnistDataBatchR r
             -> ADVal Concrete (X (ADRnnMnistParameters Concrete r))
             -> ADVal Concrete (TKScalar r)
           f (glyphR, labelR) adinputs =
             MnistRnnRanked2.rnnMnistLossFusedR
               miniBatchSize (rconcrete glyphR, rconcrete labelR)
               (fromTarget @(ADVal Concrete) adinputs)
           -- Train on one chunk of the epoch's data and, unless the
           -- network is narrow, log training and validation errors.
           runBatch :: ( Concrete (X (ADRnnMnistParameters Concrete r))
                       , StateAdam (X (ADRnnMnistParameters Concrete r)) )
                    -> (Int, [MnistDataR r])
                    -> IO ( Concrete (X (ADRnnMnistParameters Concrete r))
                          , StateAdam (X (ADRnnMnistParameters Concrete r)) )
           runBatch (!parameters, !stateAdam) (k, chunk) = do
             let -- Only full mini-batches are kept; a shorter trailing
                 -- one is dropped.
                 chunkR = map mkMnistDataBatchR
                          $ filter (\ch -> length ch == miniBatchSize)
                          $ chunksOf miniBatchSize chunk
                 res@(parameters2, _) =
                   sgdAdam @(MnistDataBatchR r)
                           @(X (ADRnnMnistParameters Concrete r))
                           f chunkR parameters stateAdam
                 trainScore =
                   ftest (length chunk) (mkMnistDataBatchR chunk) parameters2
                 testScore =
                   ftest ((totalBatchSize * maxBatches) `min` 10000)
                         testDataR parameters2
                 lenChunk = length chunk
             unless (width < 10) $ do
               hPutStrLn stderr $
                 printf "\n%s: (Batch %d with %d points)"
                        prefix k lenChunk
               hPutStrLn stderr $
                 printf "%s: Training error: %.2f%%"
                        prefix ((1 - trainScore) * 100)
               hPutStrLn stderr $
                 printf "%s: Validation error: %.2f%%"
                        prefix ((1 - testScore ) * 100)
             return res
       let -- Run epochs @n@, @n + 1@, ... up to @epochs@ and return the
           -- final parameters.
           runEpoch :: Int
                    -> ( Concrete (X (ADRnnMnistParameters Concrete r))
                       , StateAdam (X (ADRnnMnistParameters Concrete r)) )
                    -> IO (Concrete (X (ADRnnMnistParameters Concrete r)))
           runEpoch n (params2, _) | n > epochs = return params2
           runEpoch n paramsStateAdam@(!_, !_) = do
             unless (width < 10) $
               hPutStrLn stderr $ printf "\n%s: [Epoch %d]" prefix n
             let -- The shuffle seed depends only on the epoch number.
                 trainDataShuffled = shuffle (mkStdGen $ n + 5) trainData
                 chunks = take maxBatches
                          $ zip [1 ..]
                          $ chunksOf totalBatchSize trainDataShuffled
             res <- foldM runBatch paramsStateAdam chunks
             runEpoch (succ n) res
           ftk = tftk @Concrete
                      (knownSTK @(X (ADRnnMnistParameters Concrete r)))
                      targetInit
       res <- runEpoch 1 (targetInit, initialStateAdam ftk)
       let testErrorFinal =
             1 - ftest ((totalBatchSize * maxBatches) `min` 10000)
                       testDataR res
       testErrorFinal @?~ expected

{-# SPECIALIZE mnistTestCaseRNNRA
  :: String
  -> Int -> Int -> Int -> Int -> Int -> Double
  -> TestTree #-}
-- | Test group of 'mnistTestCaseRNNRA' instances.
--
-- The numeric arguments of each case are, in order: epochs, maxBatches,
-- width, miniBatchSize and totalBatchSize, followed by the expected final
-- test error at the annotated scalar type.
tensorADValMnistTestsRNNRA :: TestTree
tensorADValMnistTestsRNNRA = testGroup "RNNR ADVal MNIST tests"
  [ mnistTestCaseRNNRA "RNNRA 1 epoch, 1 batch" 1 1 128 150 5000
                       (0.6026 :: Double)
  , mnistTestCaseRNNRA "RNNRA artificial 1 2 3 4 5" 2 3 4 5 50
                       (0.8933333 :: Float)
  , mnistTestCaseRNNRA "RNNRA artificial 5 4 3 2 1" 5 4 3 2 49
                       (0.8622448979591837 :: Double)
  , mnistTestCaseRNNRA "RNNRA 1 epoch, 0 batch" 1 0 128 150 50
                       (1.0 :: Float)
  ]
-- | Build one MNIST RNN test case in which the loss is first turned into
-- a simplified AST over symbolic parameters and inputs; that single AST is
-- then interpreted in @ADVal Concrete@ for every mini-batch handed to
-- 'sgdAdam'.
--
-- Arguments, in order:
--
-- * @prefix@ - label used in the test name and in every log line;
-- * @epochs@ - number of epochs to run;
-- * @maxBatches@ - maximum number of batches taken per epoch;
-- * @width@ - width used to size the randomly initialised parameters;
-- * @miniBatchSize@ - number of samples in each mini-batch fed to 'sgdAdam';
-- * @totalBatchSize@ - number of training samples in one batch (chunk);
-- * @expected@ - expected final test error, compared with '@?~'.
mnistTestCaseRNNRI
  :: forall r.
     (Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r)
  => String
  -> Int -> Int -> Int -> Int -> Int -> r
  -> TestTree
mnistTestCaseRNNRI prefix epochs maxBatches width miniBatchSize totalBatchSize
                   expected =
  withSNat width $ \(SNat @width) ->
  let -- Initial parameters: random values drawn at the shaped parameter
      -- type with a fixed generator seed, then passed through 'forgetShape'.
      targetInit =
        forgetShape $ fst
        $ randomValue @(Concrete (X (ADRnnMnistParametersShaped
                                       Concrete width r)))
                      0.23 (mkStdGen 44)
      -- Test name: the prefix followed by the run configuration and two
      -- size figures derived from the parameter type and initial value.
      name :: String
      name = prefix ++ ": "
             ++ unwords [ show epochs, show maxBatches
                        , show width, show miniBatchSize
                        , show $ widthSTK
                          $ knownSTK @(X (ADRnnMnistParameters Concrete r))
                        , show (tsize knownSTK targetInit) ]
      -- Score of the given parameters on a batch of labelled data, as
      -- computed by 'MnistRnnRanked2.rnnMnistTestR'.
      ftest :: Int -> MnistDataBatchR r
            -> Concrete (X (ADRnnMnistParameters Concrete r))
            -> r
      ftest batch_size mnistData pars =
        MnistRnnRanked2.rnnMnistTestR
          batch_size mnistData (fromTarget @Concrete pars)
  in testCase name $ do
    hPutStrLn stderr $
      printf "\n%s: Epochs to run/max batches per epoch: %d/%d"
             prefix epochs maxBatches
    trainData <- map mkMnistDataR
                 <$> loadMnistData trainGlyphsPath trainLabelsPath
    testData <- map mkMnistDataR . take (totalBatchSize * maxBatches)
                <$> loadMnistData testGlyphsPath testLabelsPath
    let testDataR = mkMnistDataBatchR testData
        -- Batch size passed to 'ftest' for every evaluation on the test
        -- data; shared by the per-batch log and the final assertion.
        testBatchSize :: Int
        testBatchSize = (totalBatchSize * maxBatches) `min` 10000
        ftk = tftk @Concrete
                   (knownSTK @(X (ADRnnMnistParameters Concrete r)))
                   targetInit
    -- Symbolic variables for the parameters, the glyphs and the labels.
    (_, _, var, varAst) <- funToAstRevIO ftk
    (varGlyph, astGlyph) <-
      funToAstIO (FTKR (miniBatchSize
                        :$: sizeMnistHeightInt
                        :$: sizeMnistWidthInt
                        :$: ZSR) FTKScalar) id
    (varLabel, astLabel) <-
      funToAstIO (FTKR (miniBatchSize
                        :$: sizeMnistLabelInt
                        :$: ZSR) FTKScalar) id
    let -- The loss term over the symbolic variables, simplified once and
        -- shared by all invocations of @f@.
        ast :: AstTensor AstMethodLet FullSpan (TKScalar r)
        ast = simplifyInline
              $ MnistRnnRanked2.rnnMnistLossFusedR
                  miniBatchSize (astGlyph, astLabel)
                  (fromTarget varAst)
        -- Objective given to 'sgdAdam': bind the three variables in an
        -- environment and interpret @ast@ in it.
        f :: MnistDataBatchR r
          -> ADVal Concrete (X (ADRnnMnistParameters Concrete r))
          -> ADVal Concrete (TKScalar r)
        f (glyph, label) varInputs =
          let env :: AstEnv (ADVal Concrete)
              env = extendEnv var varInputs emptyEnv
              envMnist :: AstEnv (ADVal Concrete)
              envMnist = extendEnv varGlyph (rconcrete glyph)
                         $ extendEnv varLabel (rconcrete label) env
          in interpretAstFull envMnist ast
        -- Train on one numbered chunk of data and return the updated
        -- parameters together with the updated Adam state.
        runBatch :: ( Concrete (X (ADRnnMnistParameters Concrete r))
                    , StateAdam (X (ADRnnMnistParameters Concrete r)) )
                 -> (Int, [MnistDataR r])
                 -> IO ( Concrete (X (ADRnnMnistParameters Concrete r))
                       , StateAdam (X (ADRnnMnistParameters Concrete r)) )
        runBatch (!parameters, !stateAdam) (k, chunk) = do
          let -- Only mini-batches of exactly miniBatchSize samples are
              -- kept: the glyph and label variables above were created
              -- with miniBatchSize as their leading dimension.
              chunkR :: [MnistDataBatchR r]
              chunkR = map mkMnistDataBatchR
                       $ filter (\ch -> length ch == miniBatchSize)
                       $ chunksOf miniBatchSize chunk
              res :: ( Concrete (X (ADRnnMnistParameters Concrete r))
                     , StateAdam (X (ADRnnMnistParameters Concrete r)) )
              res@(parameters2, _) =
                sgdAdam @(MnistDataBatchR r)
                        @(X (ADRnnMnistParameters Concrete r))
                        f chunkR parameters stateAdam
              -- The two scores are used only by the log lines below.
              trainScore :: r
              trainScore =
                ftest (length chunk) (mkMnistDataBatchR chunk) parameters2
              testScore :: r
              testScore = ftest testBatchSize testDataR parameters2
              lenChunk :: Int
              lenChunk = length chunk
          -- Progress output is skipped when width is below 10.
          unless (width < 10) $ do
            hPutStrLn stderr $
              printf "\n%s: (Batch %d with %d points)"
                     prefix k lenChunk
            hPutStrLn stderr $
              printf "%s: Training error: %.2f%%"
                     prefix ((1 - trainScore) * 100)
            hPutStrLn stderr $
              printf "%s: Validation error: %.2f%%"
                     prefix ((1 - testScore ) * 100)
          return res
    let -- Run epochs numbered from @n@ up to @epochs@ and return the final
        -- parameters; the Adam state is threaded through and then dropped.
        runEpoch :: Int
                 -> ( Concrete (X (ADRnnMnistParameters Concrete r))
                    , StateAdam (X (ADRnnMnistParameters Concrete r)) )
                 -> IO (Concrete (X (ADRnnMnistParameters Concrete r)))
        runEpoch n (params2, _) | n > epochs = return params2
        runEpoch n paramsStateAdam@(!_, !_) = do
          unless (width < 10) $
            hPutStrLn stderr $ printf "\n%s: [Epoch %d]" prefix n
          let -- The shuffle is seeded from the epoch number.
              trainDataShuffled :: [MnistDataR r]
              trainDataShuffled = shuffle (mkStdGen $ n + 5) trainData
              -- At most maxBatches chunks, numbered from 1.
              chunks :: [(Int, [MnistDataR r])]
              chunks = take maxBatches
                       $ zip [1 ..]
                       $ chunksOf totalBatchSize trainDataShuffled
          res <- foldM runBatch paramsStateAdam chunks
          runEpoch (succ n) res
    res <- runEpoch 1 (targetInit, initialStateAdam ftk)
    let testErrorFinal = 1 - ftest testBatchSize testDataR res
    testErrorFinal @?~ expected
{-# SPECIALIZE mnistTestCaseRNNRI
  :: String
  -> Int -> Int -> Int -> Int -> Int -> Double
  -> TestTree #-}
-- | Test group for the 'mnistTestCaseRNNRI' variant of the MNIST RNN test.
--
-- Each case passes, in order: a label, epochs, max batches per epoch,
-- width, mini-batch size, total batch size and the expected final test
-- error. The type annotation on the last argument selects the scalar type
-- ('Double' or 'Float') the case runs at.
tensorADValMnistTestsRNNRI :: TestTree
tensorADValMnistTestsRNNRI = testGroup "RNNR Intermediate MNIST tests"
  [ mnistTestCaseRNNRI "RNNRI 1 epoch, 1 batch" 1 1 128 150 5000
                       (0.6026 :: Double)
  , mnistTestCaseRNNRI "RNNRI artificial 1 2 3 4 5" 2 3 4 5 50
                       (0.8933333 :: Float)
  , mnistTestCaseRNNRI "RNNRI artificial 5 4 3 2 1" 5 4 3 2 49
                       (0.8622448979591837 :: Double)
    -- With zero batches per epoch the expected error is 1.0.
  , mnistTestCaseRNNRI "RNNRI 1 epoch, 0 batch" 1 0 128 150 50
                       (1.0 :: Float)
  ]
mnistTestCaseRNNRO
:: forall r.
( Differentiable r, GoodScalar r
, PrintfArg r, AssertEqualUpToEpsilon r, ADTensorScalar r ~ r )
=> String
-> Int -> Int -> Int -> Int -> Int -> r
-> TestTree
mnistTestCaseRNNRO :: forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r,
(ADTensorScalar r :: Type) ~ (r :: Type)) =>
String -> Int -> Int -> Int -> Int -> Int -> r -> TestTree
mnistTestCaseRNNRO String
prefix Int
epochs Int
maxBatches Int
width Int
miniBatchSize Int
totalBatchSize
r
expected =
Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
width ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat @width) ->
-- Initial parameters for the test. A random value is generated by
-- randomValue at the shaped parameter type (instantiated at the type-level
-- width bound by the enclosing lambda), seeded with mkStdGen 44; only the
-- first component of the returned pair is kept and it is then passed
-- through forgetShape.
-- NOTE(review): 0.23 is the Double argument of randomValue; presumably a
-- range/scale for the generated values; confirm against its definition.
let targetInit :: NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit =
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))
-> NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
forall a b. (a -> b) -> a -> b
$ (Concrete (X (ADRnnMnistParametersShaped Concrete n r)), StdGen)
-> Concrete (X (ADRnnMnistParametersShaped Concrete n r))
forall a b. (a, b) -> a
fst
((Concrete (X (ADRnnMnistParametersShaped Concrete n r)), StdGen)
-> Concrete (X (ADRnnMnistParametersShaped Concrete n r)))
-> (Concrete (X (ADRnnMnistParametersShaped Concrete n r)), StdGen)
-> Concrete (X (ADRnnMnistParametersShaped Concrete n r))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue @(Concrete (X (ADRnnMnistParametersShaped
Concrete width r)))
Double
0.23 (Int -> StdGen
mkStdGen Int
44)
-- Test name shown by tasty: the prefix, then ": ", then the space-separated
-- values of epochs, maxBatches, width, miniBatchSize, the widthSTK of the
-- parameter tensor kind and the tsize of the initial parameters.
name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
, Int -> String
forall a. Show a => a -> String
show Int
width, Int -> String
forall a. Show a => a -> String
show Int
miniBatchSize
, Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK (X (ADRnnMnistParameters Concrete r)) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
(SingletonTK (X (ADRnnMnistParameters Concrete r)) -> Int)
-> SingletonTK (X (ADRnnMnistParameters Concrete r)) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(X (ADRnnMnistParameters Concrete r))
, Int -> String
forall a. Show a => a -> String
show (SingletonTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit) ]
-- Scores a parameter tensor on a batch of MNIST data: the parameters are
-- unpacked with fromTarget and handed, together with the batch size and
-- the batch itself, to MnistRnnRanked2.rnnMnistTestR.
-- NOTE(review): the result is presumably an accuracy fraction (the callers
-- print (1 - score) * 100 as an error percentage); confirm against
-- rnnMnistTestR.
ftest :: Int -> MnistDataBatchR r
-> Concrete (X (ADRnnMnistParameters Concrete r))
-> r
ftest :: Int
-> MnistDataBatchR r
-> Concrete (X (ADRnnMnistParameters Concrete r))
-> r
ftest Int
batch_size MnistDataBatchR r
mnistData Concrete (X (ADRnnMnistParameters Concrete r))
pars =
Int -> MnistDataBatchR r -> ADRnnMnistParameters Concrete r -> r
forall (target :: TK -> Type) r.
((target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
GoodScalar r, Differentiable r) =>
Int -> MnistDataBatchR r -> ADRnnMnistParameters target r -> r
MnistRnnRanked2.rnnMnistTestR
Int
batch_size MnistDataBatchR r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete (X (ADRnnMnistParameters Concrete r))
pars)
in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
String
prefix Int
epochs Int
maxBatches
trainData <- (MnistData r -> MnistDataR r) -> [MnistData r] -> [MnistDataR r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataR r
forall r. PrimElt r => MnistData r -> MnistDataR r
mkMnistDataR
([MnistData r] -> [MnistDataR r])
-> IO [MnistData r] -> IO [MnistDataR r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
testData <- map mkMnistDataR . take (totalBatchSize * maxBatches)
<$> loadMnistData testGlyphsPath testLabelsPath
-- All loaded test points (testData is already truncated to
-- totalBatchSize * maxBatches entries) packed into a single batch.
let testDataR = [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
testData
-- Example (glyphs, labels) input handed to gradArtifact when the gradient
-- artifact is built: the first chunk produced by chunksOf miniBatchSize on
-- the test data, batched and converted with rconcrete.
-- Calls error with "empty test data" when testData is empty.
dataInit = case Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataR r]
testData of
[MnistDataR r]
d : [[MnistDataR r]]
_ -> let (Ranked 3 r
dglyph, Ranked 2 r
dlabel) = [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
d
in (Ranked 3 r -> Concrete (TKR 3 r)
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 3 r
dglyph, Ranked 2 r -> Concrete (TKR2 2 (TKScalar r))
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 2 r
dlabel)
[] -> String -> (Concrete (TKR 3 r), Concrete (TKR2 2 (TKScalar r)))
forall a. HasCallStack => String -> a
error String
"empty test data"
-- Objective to differentiate, over AST tensors: given the parameters and a
-- (glyphs, labels) pair, takes the primal parts of glyphs and labels with
-- rprimalPart and applies MnistRnnRanked2.rnnMnistLossFusedR with
-- miniBatchSize, yielding a scalar.
f :: ( ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r
, ( AstTensor AstMethodLet FullSpan (TKR 3 r)
, AstTensor AstMethodLet FullSpan (TKR 2 r) ) )
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f = \ (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r
pars, (AstTensor AstMethodLet FullSpan (TKR 3 r)
glyphR, AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
labelR)) ->
Int
-> (PrimalOf (AstTensor AstMethodLet FullSpan) (TKR 3 r),
PrimalOf (AstTensor AstMethodLet FullSpan) (TKR2 2 (TKScalar r)))
-> ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (target :: TK -> Type) r.
(ADReady target, ADReady (PrimalOf target), GoodScalar r,
Differentiable r) =>
Int
-> (PrimalOf target (TKR 3 r), PrimalOf target (TKR 2 r))
-> ADRnnMnistParameters target r
-> target (TKScalar r)
MnistRnnRanked2.rnnMnistLossFusedR
Int
miniBatchSize (AstTensor AstMethodLet FullSpan (TKR 3 r)
-> PrimalOf (AstTensor AstMethodLet FullSpan) (TKR 3 r)
forall (target :: TK -> Type) (n :: Nat) (x :: TK).
BaseTensor target =>
target (TKR2 n x) -> PrimalOf target (TKR2 n x)
rprimalPart AstTensor AstMethodLet FullSpan (TKR 3 r)
glyphR, AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
-> PrimalOf (AstTensor AstMethodLet FullSpan) (TKR2 2 (TKScalar r))
forall (target :: TK -> Type) (n :: Nat) (x :: TK).
BaseTensor target =>
target (TKR2 n x) -> PrimalOf target (TKR2 n x)
rprimalPart AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
labelR) ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r
pars
-- Unsimplified reverse-mode artifact for f, produced by gradArtifact from
-- the example value (fromTarget targetInit, dataInit).
artRaw = ((ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
(AstTensor AstMethodLet FullSpan (TKR 3 r),
AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))))
-> AstTensor AstMethodLet FullSpan (TKScalar r))
-> Value
(ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
(AstTensor AstMethodLet FullSpan (TKR 3 r),
AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))))
-> AstArtifactRev
(X (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
(AstTensor AstMethodLet FullSpan (TKR 3 r),
AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r)))))
(TKScalar r)
forall src r tgt.
((X src :: TK) ~ (X (Value src) :: TK), KnownSTK (X src),
AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
AdaptableTarget Concrete (Value src),
(tgt :: Type)
~ (AstTensor AstMethodLet FullSpan (TKScalar r) :: Type)) =>
(src -> tgt) -> Value src -> AstArtifactRev (X src) (TKScalar r)
gradArtifact (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
(AstTensor AstMethodLet FullSpan (TKR 3 r),
AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f (Concrete (X (ADRnnMnistParameters Concrete r))
-> ADRnnMnistParameters Concrete r
forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete (X (ADRnnMnistParameters Concrete r))
NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit, (Concrete (TKR 3 r), Concrete (TKR2 2 (TKScalar r)))
dataInit)
-- The artifact artRaw after simplifyArtifactGradient; this is the artifact
-- that go interprets for every mini-batch.
art = AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
(TKScalar r)
-> AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
(TKScalar r)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifactGradient AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
(TKScalar r)
AstArtifactRev
(X (ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
(AstTensor AstMethodLet FullSpan (TKR 3 r),
AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r)))))
(TKScalar r)
artRaw
-- Runs the optimizer over a list of mini-batches, threading the pair
-- (parameters, Adam state) from one batch to the next. With no batches left
-- the pair is returned unchanged. The accumulator components are matched
-- with bang patterns in the recursive equation.
go :: [MnistDataBatchR r]
-> ( Concrete (X (ADRnnMnistParameters Concrete r))
, StateAdam (X (ADRnnMnistParameters Concrete r)) )
-> ( Concrete (X (ADRnnMnistParameters Concrete r))
, StateAdam (X (ADRnnMnistParameters Concrete r)) )
go [] (Concrete (X (ADRnnMnistParameters Concrete r))
parameters, StateAdam (X (ADRnnMnistParameters Concrete r))
stateAdam) = (Concrete (X (ADRnnMnistParameters Concrete r))
parameters, StateAdam (X (ADRnnMnistParameters Concrete r))
stateAdam)
go ((Ranked 3 r
glyph, Ranked 2 r
label) : [MnistDataBatchR r]
rest) (!Concrete (X (ADRnnMnistParameters Concrete r))
parameters, !StateAdam (X (ADRnnMnistParameters Concrete r))
stateAdam) =
-- Input of the artifact: the current parameters paired (tpair) with the
-- batch's glyphs and labels, both converted with rconcrete.
let parametersAndInput :: Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
parametersAndInput =
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> Concrete (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (X (ADRnnMnistParameters Concrete r))
parameters (Concrete (TKR 3 r)
-> Concrete (TKR2 2 (TKScalar r))
-> Concrete (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (Ranked 3 r -> Concrete (TKR 3 r)
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 3 r
glyph) (Ranked 2 r -> Concrete (TKR2 2 (TKScalar r))
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 2 r
label))
-- Gradient with respect to the parameters only: interpret art on the
-- paired input (passing Nothing for the optional Maybe argument), keep the
-- first component of the result pair and project out its parameter part
-- with tproject1.
gradient :: Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
gradient = Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall (x :: TK) (z :: TK). Concrete (TKProduct x z) -> Concrete x
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
forall a b. (a -> b) -> a -> b
$ (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))),
Concrete (TKScalar r))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
forall a b. (a, b) -> a
fst
((Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))),
Concrete (TKScalar r))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))))
-> (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))),
Concrete (TKScalar r))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
forall a b. (a -> b) -> a -> b
$ AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
(TKScalar r)
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
-> Maybe (Concrete (ADTensorKind (TKScalar r)))
-> (Concrete
(ADTensorKind
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))),
Concrete (TKScalar r))
forall (x :: TK) (z :: TK).
AstArtifactRev x z
-> Concrete x
-> Maybe (Concrete (ADTensorKind z))
-> (Concrete (ADTensorKind x), Concrete z)
revInterpretArtifact
AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
(TKScalar r)
art Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
parametersAndInput Maybe (Concrete (ADTensorKind (TKScalar r)))
Maybe (Concrete (TKScalar r))
forall a. Maybe a
Nothing
-- Recurse on the remaining batches with the parameters and Adam state
-- returned by updateWithGradientAdam (using defaultArgsAdam).
in [MnistDataBatchR r]
-> (Concrete (X (ADRnnMnistParameters Concrete r)),
StateAdam (X (ADRnnMnistParameters Concrete r)))
-> (Concrete (X (ADRnnMnistParameters Concrete r)),
StateAdam (X (ADRnnMnistParameters Concrete r)))
go [MnistDataBatchR r]
rest (forall (y :: TK).
ArgsAdam
-> StateAdam y
-> SingletonTK y
-> Concrete y
-> Concrete (ADTensorKind y)
-> (Concrete y, StateAdam y)
updateWithGradientAdam
@(X (ADRnnMnistParameters Concrete r))
ArgsAdam
defaultArgsAdam StateAdam (X (ADRnnMnistParameters Concrete r))
stateAdam SingletonTK
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
SingletonTK (X (ADRnnMnistParameters Concrete r))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete (X (ADRnnMnistParameters Concrete r))
parameters
Concrete (ADTensorKind (X (ADRnnMnistParameters Concrete r)))
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
gradient)
-- Trains on one chunk of training data (numbered k) and reports progress.
-- The chunk is split into mini-batches of miniBatchSize, go is run over
-- them starting from the given (parameters, Adam state), and the resulting
-- pair is returned. Unless width < 10, the batch number, chunk size,
-- training error and validation error are written to stderr.
runBatch :: ( Concrete (X (ADRnnMnistParameters Concrete r))
, StateAdam (X (ADRnnMnistParameters Concrete r)) )
-> (Int, [MnistDataR r])
-> IO ( Concrete (X (ADRnnMnistParameters Concrete r))
, StateAdam (X (ADRnnMnistParameters Concrete r)) )
runBatch (!Concrete (X (ADRnnMnistParameters Concrete r))
parameters, !StateAdam (X (ADRnnMnistParameters Concrete r))
stateAdam) (Int
k, [MnistDataR r]
chunk) = do
-- Mini-batches of exactly miniBatchSize points; any shorter chunk is
-- dropped by the filter.
let chunkR :: [MnistDataBatchR r]
chunkR = ([MnistDataR r] -> MnistDataBatchR r)
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR
([[MnistDataR r]] -> [MnistDataBatchR r])
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataR r] -> Bool) -> [[MnistDataR r]] -> [[MnistDataR r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataR r]
ch -> [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSize)
([[MnistDataR r]] -> [[MnistDataR r]])
-> [[MnistDataR r]] -> [[MnistDataR r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataR r]
chunk
-- Updated (parameters, Adam state) after training on every mini-batch.
res :: (Concrete (X (ADRnnMnistParameters Concrete r)),
StateAdam (X (ADRnnMnistParameters Concrete r)))
res@(Concrete (X (ADRnnMnistParameters Concrete r))
parameters2, StateAdam (X (ADRnnMnistParameters Concrete r))
_) = [MnistDataBatchR r]
-> (Concrete (X (ADRnnMnistParameters Concrete r)),
StateAdam (X (ADRnnMnistParameters Concrete r)))
-> (Concrete (X (ADRnnMnistParameters Concrete r)),
StateAdam (X (ADRnnMnistParameters Concrete r)))
go [MnistDataBatchR r]
chunkR (Concrete (X (ADRnnMnistParameters Concrete r))
parameters, StateAdam (X (ADRnnMnistParameters Concrete r))
stateAdam)
-- Score of the updated parameters on the whole training chunk.
trainScore :: r
trainScore =
Int
-> MnistDataBatchR r
-> Concrete (X (ADRnnMnistParameters Concrete r))
-> r
ftest ([MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk) ([MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
chunk) Concrete (X (ADRnnMnistParameters Concrete r))
parameters2
-- Score of the updated parameters on the test batch; the batch size passed
-- to ftest is totalBatchSize * maxBatches capped at 10000.
testScore :: r
testScore =
Int
-> MnistDataBatchR r
-> Concrete (X (ADRnnMnistParameters Concrete r))
-> r
ftest ((Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) Int -> Int -> Int
forall a. Ord a => a -> a -> a
`min` Int
10000)
MnistDataBatchR r
testDataR Concrete (X (ADRnnMnistParameters Concrete r))
parameters2
lenChunk :: Int
lenChunk = [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk
-- Progress output is suppressed for widths below 10.
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
String
prefix Int
k Int
lenChunk
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(Concrete (X (ADRnnMnistParameters Concrete r)),
StateAdam (X (ADRnnMnistParameters Concrete r)))
res
let runEpoch :: Int
-> ( Concrete (X (ADRnnMnistParameters Concrete r))
, StateAdam (X (ADRnnMnistParameters Concrete r)) )
-> IO (Concrete (X (ADRnnMnistParameters Concrete r)))
runEpoch Int
n (Concrete (X (ADRnnMnistParameters Concrete r))
params2, StateAdam (X (ADRnnMnistParameters Concrete r))
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (X (ADRnnMnistParameters Concrete r))
params2
runEpoch Int
n paramsStateAdam :: (Concrete (X (ADRnnMnistParameters Concrete r)),
StateAdam (X (ADRnnMnistParameters Concrete r)))
paramsStateAdam@(!Concrete (X (ADRnnMnistParameters Concrete r))
_, !StateAdam (X (ADRnnMnistParameters Concrete r))
_) = do
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
width Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
let trainDataShuffled :: [MnistDataR r]
trainDataShuffled = StdGen -> [MnistDataR r] -> [MnistDataR r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataR r]
trainData
chunks :: [(Int, [MnistDataR r])]
chunks = Int -> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
([(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])])
-> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
Item [Int]
1 ..]
([[MnistDataR r]] -> [(Int, [MnistDataR r])])
-> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataR r]
trainDataShuffled
res <- ((Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> (Int, [MnistDataR r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> [(Int, [MnistDataR r])]
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> (Int, [MnistDataR r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(Concrete (X (ADRnnMnistParameters Concrete r)),
StateAdam (X (ADRnnMnistParameters Concrete r)))
-> (Int, [MnistDataR r])
-> IO
(Concrete (X (ADRnnMnistParameters Concrete r)),
StateAdam (X (ADRnnMnistParameters Concrete r)))
runBatch (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(Concrete (X (ADRnnMnistParameters Concrete r)),
StateAdam (X (ADRnnMnistParameters Concrete r)))
paramsStateAdam [(Int, [MnistDataR r])]
chunks
runEpoch (succ n) res
ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(X (ADRnnMnistParameters
Concrete r)))
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat SizeMnistHeight ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKProduct
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ((':) @Nat n ('[] @Nat))) (TKScalar r)))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
targetInit
res <- runEpoch 1 (targetInit, initialStateAdam ftk)
let testErrorFinal =
r
1 r -> r -> r
forall a. Num a => a -> a -> a
- Int
-> MnistDataBatchR r
-> Concrete (X (ADRnnMnistParameters Concrete r))
-> r
ftest ((Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) Int -> Int -> Int
forall a. Ord a => a -> a -> a
`min` Int
10000)
MnistDataBatchR r
testDataR Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r)))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 2 (TKScalar r)))
(TKR2 1 (TKScalar r))))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))
Concrete (X (ADRnnMnistParameters Concrete r))
res
assertEqualUpToEpsilon 1e-1 expected testErrorFinal
-- Ask GHC to generate a copy of the polymorphic 'mnistTestCaseRNNRO'
-- with the scalar type fixed to 'Double'.
{-# SPECIALIZE mnistTestCaseRNNRO
:: String
-> Int -> Int -> Int -> Int -> Int -> Double
-> TestTree #-}
-- | Test group named \"RNNR Once MNIST tests\": four invocations of
-- 'mnistTestCaseRNNRO', alternating between 'Double' and 'Float' scalars.
--
-- Each case passes a label, five 'Int' settings and the expected final
-- test error, which the test case compares against the measured value.
--
-- NOTE(review): the five 'Int' arguments presumably follow the same order
-- as in 'mnistTestCaseRNNRA' (epochs, maxBatches, width, miniBatchSize,
-- totalBatchSize) -- confirm against the head of 'mnistTestCaseRNNRO'.
tensorADValMnistTestsRNNRO :: TestTree
tensorADValMnistTestsRNNRO = testGroup "RNNR Once MNIST tests"
  [ mnistTestCaseRNNRO "RNNRO 1 epoch, 1 batch" 1 1 128 150 5000
                       (0.6026 :: Double)
  , mnistTestCaseRNNRO "RNNRO artificial 1 2 3 4 5" 2 3 4 5 50
                       (0.8933333 :: Float)
  , mnistTestCaseRNNRO "RNNRO artificial 5 4 3 2 1" 5 4 3 2 49
                       (0.8928571428571429 :: Double)
    -- Zero batches requested; the expected error is 1.0.
  , mnistTestCaseRNNRO "RNNRO 1 epoch, 0 batch" 1 0 128 150 50
                       (1.0 :: Float)
  ]