{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module TestMnistCNNR
( testTrees
) where
import Prelude
import Control.Monad (foldM, unless)
import System.IO (hPutStrLn, stderr)
import System.Random
import Test.Tasty
import Test.Tasty.HUnit hiding (assert)
import Text.Printf
import Data.Array.Nested.Ranked.Shape
import HordeAd
import HordeAd.Core.Adaptor
import HordeAd.Core.AstEnv
import HordeAd.Core.AstFreshId
import HordeAd.Core.AstInterpret
import EqEpsilon
import MnistCnnRanked2 qualified
import MnistData
-- | All test groups exported by this module: the ADVal-based
-- 'tensorADValMnistTestsCNNRA' suite plus the @CNNRI@ and @CNNRO@ suites.
testTrees :: [TestTree]
testTrees = [ tensorADValMnistTestsCNNRA
            , tensorADValMnistTestsCNNRI
            , tensorADValMnistTestsCNNRO
            ]
-- | Tensor kind of the CNN parameters handled by the training loops in
-- this module: the 'X' image of @MnistCnnRanked2.ADCnnMnistParameters@
-- instantiated at the 'Concrete' target.
type XParams r =
  X (MnistCnnRanked2.ADCnnMnistParameters Concrete r)
mnistTestCaseCNNRA
:: forall r.
(Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r)
=> String
-> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> r
-> TestTree
mnistTestCaseCNNRA :: forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNRA String
prefix Int
epochs Int
maxBatches Int
khInt Int
kwInt Int
c_outInt Int
n_hiddenInt
Int
miniBatchSize Int
totalBatchSize r
expected =
Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
khInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_khSNat :: SNat kh) ->
Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
kwInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_kwSNat :: SNat kw) ->
Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
c_outInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_c_outSNat :: SNat c_out) ->
Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
n_hiddenInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_n_hiddenSNat :: SNat n_hidden) ->
let targetInit :: NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit =
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
forall a b. (a -> b) -> a -> b
$ (Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r))
forall a b. (a, b) -> a
fst
((Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r)))
-> (Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue
@(Concrete (X (MnistCnnRanked2.ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistWidth
kh kw c_out n_hidden r)))
Double
0.4 (Int -> StdGen
mkStdGen Int
44)
name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
, Int -> String
forall a. Show a => a -> String
show Int
khInt, Int -> String
forall a. Show a => a -> String
show Int
kwInt
, Int -> String
forall a. Show a => a -> String
show Int
c_outInt, Int -> String
forall a. Show a => a -> String
show Int
n_hiddenInt
, Int -> String
forall a. Show a => a -> String
show Int
miniBatchSize
, Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK (SingletonTK (XParams r) -> Int) -> SingletonTK (XParams r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams r)
, Int -> String
forall a. Show a => a -> String
show (SingletonTK
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit) ]
ftest :: Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest :: Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest Int
batch_size MnistDataBatchR r
mnistData Concrete (XParams r)
pars =
Int -> MnistDataBatchR r -> ADCnnMnistParameters Concrete r -> r
forall (target :: TK -> Type) r.
((target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
GoodScalar r, Differentiable r) =>
Int -> MnistDataBatchR r -> ADCnnMnistParameters Concrete r -> r
MnistCnnRanked2.convMnistTestR
Int
batch_size MnistDataBatchR r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete (XParams r)
pars)
in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
String
prefix Int
epochs Int
maxBatches
trainData <- (MnistData r -> MnistDataR r) -> [MnistData r] -> [MnistDataR r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataR r
forall r. PrimElt r => MnistData r -> MnistDataR r
mkMnistDataR
([MnistData r] -> [MnistDataR r])
-> IO [MnistData r] -> IO [MnistDataR r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
testData <- map mkMnistDataR . take (totalBatchSize * maxBatches)
<$> loadMnistData testGlyphsPath testLabelsPath
let testDataR = [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
testData
f :: MnistDataBatchR r -> ADVal Concrete (XParams r)
-> ADVal Concrete (TKScalar r)
f (Ranked 3 r
glyphR, Ranked 2 r
labelR) ADVal Concrete (XParams r)
adinputs =
Int
-> (PrimalOf (ADVal Concrete) (TKR 3 r),
PrimalOf (ADVal Concrete) (TKR2 2 (TKScalar r)))
-> ADCnnMnistParameters (ADVal Concrete) r
-> ADVal Concrete (TKScalar r)
forall (target :: TK -> Type) r.
(ADReady target, ADReady (PrimalOf target), GoodScalar r,
Differentiable r) =>
Int
-> (PrimalOf target (TKR 3 r), PrimalOf target (TKR 2 r))
-> ADCnnMnistParameters target r
-> target (TKScalar r)
MnistCnnRanked2.convMnistLossFusedR
Int
miniBatchSize (Ranked 3 r -> Concrete (TKR 3 r)
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 3 r
glyphR, Ranked 2 r -> Concrete (TKR2 2 (TKScalar r))
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 2 r
labelR)
(ADVal Concrete (X (ADCnnMnistParameters (ADVal Concrete) r))
-> ADCnnMnistParameters (ADVal Concrete) r
forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget ADVal Concrete (X (ADCnnMnistParameters (ADVal Concrete) r))
ADVal Concrete (XParams r)
adinputs)
runBatch :: (Concrete (XParams r), StateAdam (XParams r))
-> (Int, [MnistDataR r])
-> IO (Concrete (XParams r), StateAdam (XParams r))
runBatch (!Concrete (XParams r)
parameters, !StateAdam (XParams r)
stateAdam) (Int
k, [MnistDataR r]
chunk) = do
let chunkR :: [MnistDataBatchR r]
chunkR = ([MnistDataR r] -> MnistDataBatchR r)
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR
([[MnistDataR r]] -> [MnistDataBatchR r])
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataR r] -> Bool) -> [[MnistDataR r]] -> [[MnistDataR r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataR r]
ch -> [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSize)
([[MnistDataR r]] -> [[MnistDataR r]])
-> [[MnistDataR r]] -> [[MnistDataR r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataR r]
chunk
res :: (Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
res@(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
parameters2, StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
_) =
(MnistDataBatchR r
-> ADVal
Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> ADVal Concrete (TKScalar r))
-> [MnistDataBatchR r]
-> Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> (Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall a (x :: TK) (z :: TK).
KnownSTK x =>
(a -> ADVal Concrete x -> ADVal Concrete z)
-> [a] -> Concrete x -> StateAdam x -> (Concrete x, StateAdam x)
sgdAdam MnistDataBatchR r
-> ADVal
Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> ADVal Concrete (TKScalar r)
MnistDataBatchR r
-> ADVal Concrete (XParams r) -> ADVal Concrete (TKScalar r)
f [MnistDataBatchR r]
chunkR Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
parameters StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
StateAdam (XParams r)
stateAdam
trainScore :: r
trainScore =
Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest ([MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk) ([MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
chunk) Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
parameters2
testScore :: r
testScore =
Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest (Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) MnistDataBatchR r
testDataR Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
parameters2
lenChunk :: Int
lenChunk = [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
String
prefix Int
k Int
lenChunk
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
res
let runEpoch :: Int
-> (Concrete (XParams r), StateAdam (XParams r))
-> IO (Concrete (XParams r))
runEpoch Int
n (Concrete (XParams r)
params2, StateAdam (XParams r)
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
params2
runEpoch Int
n paramsStateAdam :: (Concrete (XParams r), StateAdam (XParams r))
paramsStateAdam@(!Concrete (XParams r)
_, !StateAdam (XParams r)
_) = do
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
let trainDataShuffled :: [MnistDataR r]
trainDataShuffled = StdGen -> [MnistDataR r] -> [MnistDataR r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataR r]
trainData
chunks :: [(Int, [MnistDataR r])]
chunks = Int -> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
([(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])])
-> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..]
([[MnistDataR r]] -> [(Int, [MnistDataR r])])
-> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataR r]
trainDataShuffled
res <- ((Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> (Int, [MnistDataR r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))))
-> (Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> [(Int, [MnistDataR r])]
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> (Int, [MnistDataR r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
(Concrete (XParams r), StateAdam (XParams r))
-> (Int, [MnistDataR r])
-> IO (Concrete (XParams r), StateAdam (XParams r))
runBatch (Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
(Concrete (XParams r), StateAdam (XParams r))
paramsStateAdam [(Int, [MnistDataR r])]
chunks
runEpoch (succ n) res
ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams r)) Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit
res <- runEpoch 1 (targetInit, initialStateAdam ftk)
let testErrorFinal =
r
1 r -> r -> r
forall a. Num a => a -> a -> a
- Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest (Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) MnistDataBatchR r
testDataR Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
res
testErrorFinal @?~ expected
-- Ask GHC for a monomorphic copy of 'mnistTestCaseCNNRA' at @r ~ Double@.
{-# SPECIALIZE mnistTestCaseCNNRA
      :: String
      -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int
      -> Double
      -> TestTree #-}
-- | Test group running 'mnistTestCaseCNNRA' (ranked CNN, ADVal, Adam) on a
-- few hyperparameter configurations.
--
-- Positional arguments of 'mnistTestCaseCNNRA', in order:
--
--   * test-name prefix
--   * @epochs@, @maxBatches@
--   * @kh@, @kw@ (kernel height/width), @c_out@, @n_hidden@
--   * @miniBatchSize@, @totalBatchSize@
--   * @expected@: the final test error (@1 - test score@), compared up to
--     epsilon; its type annotation selects the scalar type @r@.
tensorADValMnistTestsCNNRA :: TestTree
tensorADValMnistTestsCNNRA = testGroup "CNNR ADVal MNIST tests"
  [ --                  name                          ep mb kh kw co nh mini total
    mnistTestCaseCNNRA "CNNRA 1 epoch, 1 batch"      1  1  4  4  8  16 1    1
                       (1 :: Double)
  , mnistTestCaseCNNRA "CNNRA artificial 1 2 3 4 5"  1  1  2  3  4  5  1    10
                       (1 :: Float)
  , mnistTestCaseCNNRA "CNNRA artificial 5 4 3 2 1"  5  4  3  2  1  1  1    1
                       (1 :: Double)
  , mnistTestCaseCNNRA "CNNRA 1 epoch, 0 batch"      1  0  4  4  16 64 16   50
                       (1.0 :: Float)
  ]
mnistTestCaseCNNRI
:: forall r.
(Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r)
=> String
-> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> r
-> TestTree
mnistTestCaseCNNRI :: forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNRI String
prefix Int
epochs Int
maxBatches Int
khInt Int
kwInt Int
c_outInt Int
n_hiddenInt
Int
miniBatchSize Int
totalBatchSize r
expected =
Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
khInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_khSNat :: SNat kh) ->
Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
kwInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_kwSNat :: SNat kw) ->
Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
c_outInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_c_outSNat :: SNat c_out) ->
Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
n_hiddenInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_n_hiddenSNat :: SNat n_hidden) ->
let targetInit :: NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit =
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
forall a b. (a -> b) -> a -> b
$ (Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r))
forall a b. (a, b) -> a
fst
((Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r)))
-> (Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue
@(Concrete (X (MnistCnnRanked2.ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistWidth
kh kw c_out n_hidden r)))
Double
0.4 (Int -> StdGen
mkStdGen Int
44)
name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
, Int -> String
forall a. Show a => a -> String
show Int
khInt, Int -> String
forall a. Show a => a -> String
show Int
kwInt
, Int -> String
forall a. Show a => a -> String
show Int
c_outInt, Int -> String
forall a. Show a => a -> String
show Int
n_hiddenInt
, Int -> String
forall a. Show a => a -> String
show Int
miniBatchSize
, Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK (SingletonTK (XParams r) -> Int) -> SingletonTK (XParams r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams r)
, Int -> String
forall a. Show a => a -> String
show (SingletonTK
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit) ]
ftest :: Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest :: Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest Int
batch_size MnistDataBatchR r
mnistData Concrete (XParams r)
pars =
Int -> MnistDataBatchR r -> ADCnnMnistParameters Concrete r -> r
forall (target :: TK -> Type) r.
((target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
GoodScalar r, Differentiable r) =>
Int -> MnistDataBatchR r -> ADCnnMnistParameters Concrete r -> r
MnistCnnRanked2.convMnistTestR
Int
batch_size MnistDataBatchR r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete (XParams r)
pars)
in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
String
prefix Int
epochs Int
maxBatches
trainData <- (MnistData r -> MnistDataR r) -> [MnistData r] -> [MnistDataR r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataR r
forall r. PrimElt r => MnistData r -> MnistDataR r
mkMnistDataR
([MnistData r] -> [MnistDataR r])
-> IO [MnistData r] -> IO [MnistDataR r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
testData <- map mkMnistDataR . take (totalBatchSize * maxBatches)
<$> loadMnistData testGlyphsPath testLabelsPath
let testDataR = [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
testData
ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams r)) Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit
(_, _, var, varAst2) <- funToAstRevIO ftk
(varGlyph, astGlyph) <-
funToAstIO (FTKR (miniBatchSize
:$: sizeMnistHeightInt
:$: sizeMnistWidthInt
:$: ZSR) FTKScalar) id
(varLabel, astLabel) <-
funToAstIO (FTKR (miniBatchSize
:$: sizeMnistLabelInt
:$: ZSR) FTKScalar) id
let ast :: AstTensor AstMethodLet FullSpan (TKScalar r)
ast = AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (z :: TK) (s :: AstSpanType).
AstSpan s =>
AstTensor AstMethodLet s z -> AstTensor AstMethodLet s z
simplifyInline
(AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall a b. (a -> b) -> a -> b
$ Int
-> (PrimalOf
(AstTensor AstMethodLet FullSpan) (TKR2 3 (TKScalar r)),
PrimalOf (AstTensor AstMethodLet FullSpan) (TKR2 2 (TKScalar r)))
-> ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (target :: TK -> Type) r.
(ADReady target, ADReady (PrimalOf target), GoodScalar r,
Differentiable r) =>
Int
-> (PrimalOf target (TKR 3 r), PrimalOf target (TKR 2 r))
-> ADCnnMnistParameters target r
-> target (TKScalar r)
MnistCnnRanked2.convMnistLossFusedR
Int
miniBatchSize (AstTensor AstMethodLet PrimalSpan (TKR2 3 (TKScalar r))
PrimalOf (AstTensor AstMethodLet FullSpan) (TKR2 3 (TKScalar r))
astGlyph, AstTensor AstMethodLet PrimalSpan (TKR2 2 (TKScalar r))
PrimalOf (AstTensor AstMethodLet FullSpan) (TKR2 2 (TKScalar r))
astLabel)
(AstTensor
AstMethodLet
FullSpan
(X (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r))
-> ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r
forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget AstTensor
AstMethodLet
FullSpan
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
AstTensor
AstMethodLet
FullSpan
(X (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r))
varAst2)
f :: MnistDataBatchR r -> ADVal Concrete (XParams r)
-> ADVal Concrete (TKScalar r)
f (Ranked 3 r
glyph, Ranked 2 r
label) ADVal Concrete (XParams r)
varInputs =
let env :: AstEnv (ADVal Concrete)
env = AstVarName
FullSpan
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> ADVal
Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
FullSpan
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
var ADVal
Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
ADVal Concrete (XParams r)
varInputs AstEnv (ADVal Concrete)
forall (target :: TK -> Type). AstEnv target
emptyEnv
envMnist :: AstEnv (ADVal Concrete)
envMnist = AstVarName PrimalSpan (TKR2 3 (TKScalar r))
-> ADVal Concrete (TKR2 3 (TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName PrimalSpan (TKR2 3 (TKScalar r))
varGlyph (Ranked 3 r -> ADVal Concrete (TKR2 3 (TKScalar r))
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 3 r
glyph)
(AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete))
-> AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete)
forall a b. (a -> b) -> a -> b
$ AstVarName PrimalSpan (TKR2 2 (TKScalar r))
-> ADVal Concrete (TKR2 2 (TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName PrimalSpan (TKR2 2 (TKScalar r))
varLabel (Ranked 2 r -> ADVal Concrete (TKR2 2 (TKScalar r))
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 2 r
label) AstEnv (ADVal Concrete)
env
in AstEnv (ADVal Concrete)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> ADVal Concrete (TKScalar r)
forall (target :: TK -> Type) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull AstEnv (ADVal Concrete)
envMnist AstTensor AstMethodLet FullSpan (TKScalar r)
ast
runBatch :: (Concrete (XParams r), StateAdam (XParams r))
-> (Int, [MnistDataR r])
-> IO (Concrete (XParams r), StateAdam (XParams r))
runBatch (!Concrete (XParams r)
parameters, !StateAdam (XParams r)
stateAdam) (Int
k, [MnistDataR r]
chunk) = do
let chunkR :: [MnistDataBatchR r]
chunkR = ([MnistDataR r] -> MnistDataBatchR r)
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR
([[MnistDataR r]] -> [MnistDataBatchR r])
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataR r] -> Bool) -> [[MnistDataR r]] -> [[MnistDataR r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataR r]
ch -> [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSize)
([[MnistDataR r]] -> [[MnistDataR r]])
-> [[MnistDataR r]] -> [[MnistDataR r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataR r]
chunk
res :: (Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
res@(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
parameters2, StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
_) =
(MnistDataBatchR r
-> ADVal
Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> ADVal Concrete (TKScalar r))
-> [MnistDataBatchR r]
-> Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> (Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall a (x :: TK) (z :: TK).
KnownSTK x =>
(a -> ADVal Concrete x -> ADVal Concrete z)
-> [a] -> Concrete x -> StateAdam x -> (Concrete x, StateAdam x)
sgdAdam MnistDataBatchR r
-> ADVal
Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> ADVal Concrete (TKScalar r)
MnistDataBatchR r
-> ADVal Concrete (XParams r) -> ADVal Concrete (TKScalar r)
f [MnistDataBatchR r]
chunkR Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
parameters StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
StateAdam (XParams r)
stateAdam
!trainScore :: r
trainScore =
Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest ([MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk) ([MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
chunk) Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
parameters2
!testScore :: r
testScore =
Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest (Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) MnistDataBatchR r
testDataR Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
parameters2
!lenChunk :: Int
lenChunk = [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
String
prefix Int
k Int
lenChunk
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
res
let runEpoch :: Int
-> (Concrete (XParams r), StateAdam (XParams r))
-> IO (Concrete (XParams r))
runEpoch Int
n (Concrete (XParams r)
params2, StateAdam (XParams r)
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
params2
runEpoch Int
n paramsStateAdam :: (Concrete (XParams r), StateAdam (XParams r))
paramsStateAdam@(!Concrete (XParams r)
_, !StateAdam (XParams r)
_) = do
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
let trainDataShuffled :: [MnistDataR r]
trainDataShuffled = StdGen -> [MnistDataR r] -> [MnistDataR r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataR r]
trainData
chunks :: [(Int, [MnistDataR r])]
chunks = Int -> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
([(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])])
-> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..]
([[MnistDataR r]] -> [(Int, [MnistDataR r])])
-> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataR r]
trainDataShuffled
res <- ((Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> (Int, [MnistDataR r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))))
-> (Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> [(Int, [MnistDataR r])]
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> (Int, [MnistDataR r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
(Concrete (XParams r), StateAdam (XParams r))
-> (Int, [MnistDataR r])
-> IO (Concrete (XParams r), StateAdam (XParams r))
runBatch (Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
(Concrete (XParams r), StateAdam (XParams r))
paramsStateAdam [(Int, [MnistDataR r])]
chunks
runEpoch (succ n) res
res <- runEpoch 1 (targetInit, initialStateAdam ftk)
let testErrorFinal =
r
1 r -> r -> r
forall a. Num a => a -> a -> a
- Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest (Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) MnistDataBatchR r
testDataR Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
res
testErrorFinal @?~ expected
-- Ask GHC to also compile a copy of 'mnistTestCaseCNNRI' specialised to
-- r = Double. This pragma covers only Double; the Float instantiations used
-- in 'tensorADValMnistTestsCNNRI' are not specialised by it.
{-# SPECIALIZE mnistTestCaseCNNRI
:: String
-> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Double
-> TestTree #-}
-- | Tasty test group \"CNNR Intermediate MNIST tests\": four runs of
-- 'mnistTestCaseCNNRI', two at Double and two at Float precision.
--
-- NOTE(review): this text is a type-annotated rendering of the source
-- (the signature is repeated and identifiers are preceded by their types),
-- not directly compilable Haskell.
--
-- Assumed argument order, presumably the same as 'mnistTestCaseCNNRA' and
-- 'mnistTestCaseCNNRO' (TODO confirm against mnistTestCaseCNNRI's head):
-- name prefix, epochs, max batches per epoch, kernel height, kernel width,
-- output channels, hidden width, mini-batch size, total batch size, and the
-- expected value that the test compares (up to epsilon) with its final
-- test error.
tensorADValMnistTestsCNNRI :: TestTree
tensorADValMnistTestsCNNRI :: TestTree
tensorADValMnistTestsCNNRI = String -> [TestTree] -> TestTree
testGroup String
"CNNR Intermediate MNIST tests"
-- Case 1 (Double): "1 epoch, 1 batch", 4x4 kernel, 8 channels, 16 hidden
-- units (per the assumed argument order above).
[ String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Double
-> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNRI String
"CNNRI 1 epoch, 1 batch" Int
1 Int
1 Int
4 Int
4 Int
8 Int
16 Int
1 Int
1
(Double
1 :: Double)
-- Case 2 (Float): small "artificial" sizes; the name lists the kernel,
-- channel and hidden sizes 2 3 4 5 (presumably after the batch count 1).
, String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Float
-> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNRI String
"CNNRI artificial 1 2 3 4 5" Int
1 Int
1 Int
2 Int
3 Int
4 Int
5 Int
1 Int
10
(Float
1 :: Float)
-- Case 3 (Double): "artificial 5 4 3 2 1", i.e. 5 epochs and 4 batches
-- under the assumed argument order.
, String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Double
-> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNRI String
"CNNRI artificial 5 4 3 2 1" Int
5 Int
4 Int
3 Int
2 Int
1 Int
1 Int
1 Int
1
(Double
1 :: Double)
-- Case 4 (Float): "1 epoch, 0 batch"; max batches is 0, so presumably no
-- training batches run. TODO confirm how mnistTestCaseCNNRI handles an
-- empty test set in this case.
, String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Float
-> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNRI String
"CNNRI 1 epoch, 0 batch" Int
1 Int
0 Int
4 Int
4 Int
16 Int
64 Int
16 Int
50
(Float
1.0 :: Float)
]
mnistTestCaseCNNRO
:: forall r.
( Differentiable r, GoodScalar r
, PrintfArg r, AssertEqualUpToEpsilon r, ADTensorScalar r ~ r )
=> String
-> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> r
-> TestTree
mnistTestCaseCNNRO :: forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r,
(ADTensorScalar r :: Type) ~ (r :: Type)) =>
String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNRO String
prefix Int
epochs Int
maxBatches Int
khInt Int
kwInt Int
c_outInt Int
n_hiddenInt
Int
miniBatchSize Int
totalBatchSize r
expected =
Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
khInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_khSNat :: SNat kh) ->
Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
kwInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_kwSNat :: SNat kw) ->
Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
c_outInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_c_outSNat :: SNat c_out) ->
Int
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Nat). KnownNat n => SNat n -> r) -> r
withSNat Int
n_hiddenInt ((forall (n :: Nat). KnownNat n => SNat n -> TestTree) -> TestTree)
-> (forall (n :: Nat). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_n_hiddenSNat :: SNat n_hidden) ->
let targetInit :: NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit =
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
forall vals. ForgetShape vals => vals -> NoShape vals
forgetShape (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r)))))
-> NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
forall a b. (a -> b) -> a -> b
$ (Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r))
forall a b. (a, b) -> a
fst
((Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r)))
-> (Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight n n n n r))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue
@(Concrete (X (MnistCnnRanked2.ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistWidth
kh kw c_out n_hidden r)))
Double
0.4 (Int -> StdGen
mkStdGen Int
44)
name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
, Int -> String
forall a. Show a => a -> String
show Int
khInt, Int -> String
forall a. Show a => a -> String
show Int
kwInt
, Int -> String
forall a. Show a => a -> String
show Int
c_outInt, Int -> String
forall a. Show a => a -> String
show Int
n_hiddenInt
, Int -> String
forall a. Show a => a -> String
show Int
miniBatchSize
, Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK (XParams r) -> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK (SingletonTK (XParams r) -> Int) -> SingletonTK (XParams r) -> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams r)
, Int -> String
forall a. Show a => a -> String
show (SingletonTK
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit) ]
ftest :: Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest :: Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest Int
batch_size MnistDataBatchR r
mnistData Concrete (XParams r)
pars =
Int -> MnistDataBatchR r -> ADCnnMnistParameters Concrete r -> r
forall (target :: TK -> Type) r.
((target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
GoodScalar r, Differentiable r) =>
Int -> MnistDataBatchR r -> ADCnnMnistParameters Concrete r -> r
MnistCnnRanked2.convMnistTestR
Int
batch_size MnistDataBatchR r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete (XParams r)
pars)
in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
String
prefix Int
epochs Int
maxBatches
trainData <- (MnistData r -> MnistDataR r) -> [MnistData r] -> [MnistDataR r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataR r
forall r. PrimElt r => MnistData r -> MnistDataR r
mkMnistDataR
([MnistData r] -> [MnistDataR r])
-> IO [MnistData r] -> IO [MnistDataR r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
testData <- map mkMnistDataR . take (totalBatchSize * maxBatches)
<$> loadMnistData testGlyphsPath testLabelsPath
let testDataR = [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
testData
dataInit = case Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataR r]
testData of
[MnistDataR r]
d : [[MnistDataR r]]
_ -> let (Ranked 3 r
dglyph, Ranked 2 r
dlabel) = [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
d
in (Ranked 3 r -> Concrete (TKR 3 r)
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 3 r
dglyph, Ranked 2 r -> Concrete (TKR2 2 (TKScalar r))
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 2 r
dlabel)
[] -> String -> (Concrete (TKR 3 r), Concrete (TKR2 2 (TKScalar r)))
forall a. HasCallStack => String -> a
error String
"empty test data"
f :: ( MnistCnnRanked2.ADCnnMnistParameters
(AstTensor AstMethodLet FullSpan) r
, ( AstTensor AstMethodLet FullSpan (TKR 3 r)
, AstTensor AstMethodLet FullSpan (TKR 2 r) ) )
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f = \ (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r
pars, (AstTensor AstMethodLet FullSpan (TKR 3 r)
glyphR, AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
labelR)) ->
Int
-> (PrimalOf (AstTensor AstMethodLet FullSpan) (TKR 3 r),
PrimalOf (AstTensor AstMethodLet FullSpan) (TKR2 2 (TKScalar r)))
-> ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (target :: TK -> Type) r.
(ADReady target, ADReady (PrimalOf target), GoodScalar r,
Differentiable r) =>
Int
-> (PrimalOf target (TKR 3 r), PrimalOf target (TKR 2 r))
-> ADCnnMnistParameters target r
-> target (TKScalar r)
MnistCnnRanked2.convMnistLossFusedR
Int
miniBatchSize (AstTensor AstMethodLet FullSpan (TKR 3 r)
-> PrimalOf (AstTensor AstMethodLet FullSpan) (TKR 3 r)
forall (target :: TK -> Type) (n :: Nat) (x :: TK).
BaseTensor target =>
target (TKR2 n x) -> PrimalOf target (TKR2 n x)
rprimalPart AstTensor AstMethodLet FullSpan (TKR 3 r)
glyphR, AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
-> PrimalOf (AstTensor AstMethodLet FullSpan) (TKR2 2 (TKScalar r))
forall (target :: TK -> Type) (n :: Nat) (x :: TK).
BaseTensor target =>
target (TKR2 n x) -> PrimalOf target (TKR2 n x)
rprimalPart AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))
labelR) ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r
pars
artRaw = ((ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
(AstTensor AstMethodLet FullSpan (TKR 3 r),
AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))))
-> AstTensor AstMethodLet FullSpan (TKScalar r))
-> Value
(ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
(AstTensor AstMethodLet FullSpan (TKR 3 r),
AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))))
-> AstArtifactRev
(X (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
(AstTensor AstMethodLet FullSpan (TKR 3 r),
AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r)))))
(TKScalar r)
forall src r tgt.
((X src :: TK) ~ (X (Value src) :: TK), KnownSTK (X src),
AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
AdaptableTarget Concrete (Value src),
(tgt :: Type)
~ (AstTensor AstMethodLet FullSpan (TKScalar r) :: Type)) =>
(src -> tgt) -> Value src -> AstArtifactRev (X src) (TKScalar r)
gradArtifact (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
(AstTensor AstMethodLet FullSpan (TKR 3 r),
AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r))))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f (Concrete (XParams r) -> ADCnnMnistParameters Concrete r
forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete (XParams r)
NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit, (Concrete (TKR 3 r), Concrete (TKR2 2 (TKScalar r)))
dataInit)
art = AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
(TKScalar r)
-> AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
(TKScalar r)
forall (x :: TK) (z :: TK).
AstArtifactRev x z -> AstArtifactRev x z
simplifyArtifactGradient AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
(TKScalar r)
AstArtifactRev
(X (ADCnnMnistParameters (AstTensor AstMethodLet FullSpan) r,
(AstTensor AstMethodLet FullSpan (TKR 3 r),
AstTensor AstMethodLet FullSpan (TKR2 2 (TKScalar r)))))
(TKScalar r)
artRaw
go :: [MnistDataBatchR r]
-> (Concrete (XParams r), StateAdam (XParams r))
-> (Concrete (XParams r), StateAdam (XParams r))
go [] (Concrete (XParams r)
parameters, StateAdam (XParams r)
stateAdam) = (Concrete (XParams r)
parameters, StateAdam (XParams r)
stateAdam)
go ((Ranked 3 r
glyph, Ranked 2 r
label) : [MnistDataBatchR r]
rest) (!Concrete (XParams r)
parameters, !StateAdam (XParams r)
stateAdam) =
let parametersAndInput :: Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
parametersAndInput =
Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> Concrete (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
parameters (Concrete (TKR 3 r)
-> Concrete (TKR2 2 (TKScalar r))
-> Concrete (TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))
forall (x :: TK) (z :: TK).
Concrete x -> Concrete z -> Concrete (TKProduct x z)
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target x -> target z -> target (TKProduct x z)
tpair (Ranked 3 r -> Concrete (TKR 3 r)
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 3 r
glyph) (Ranked 2 r -> Concrete (TKR2 2 (TKScalar r))
forall r (target :: TK -> Type) (n :: Nat).
(GoodScalar r, BaseTensor target) =>
Ranked n r -> target (TKR n r)
rconcrete Ranked 2 r
label))
gradient :: Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
gradient =
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall (x :: TK) (z :: TK). Concrete (TKProduct x z) -> Concrete x
forall (target :: TK -> Type) (x :: TK) (z :: TK).
BaseTensor target =>
target (TKProduct x z) -> target x
tproject1 (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
forall a b. (a -> b) -> a -> b
$ (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))),
Concrete (TKScalar r))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
forall a b. (a, b) -> a
fst
((Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))),
Concrete (TKScalar r))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))))
-> (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r)))),
Concrete (TKScalar r))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
forall a b. (a -> b) -> a -> b
$ AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
(TKScalar r)
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
-> Maybe (Concrete (ADTensorKind (TKScalar r)))
-> (Concrete
(ADTensorKind
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))),
Concrete (TKScalar r))
forall (x :: TK) (z :: TK).
AstArtifactRev x z
-> Concrete x
-> Maybe (Concrete (ADTensorKind z))
-> (Concrete (ADTensorKind x), Concrete z)
revInterpretArtifact AstArtifactRev
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
(TKScalar r)
art Concrete
(TKProduct
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
(TKProduct (TKR 3 r) (TKR2 2 (TKScalar r))))
parametersAndInput Maybe (Concrete (ADTensorKind (TKScalar r)))
Maybe (Concrete (TKScalar r))
forall a. Maybe a
Nothing
in [MnistDataBatchR r]
-> (Concrete (XParams r), StateAdam (XParams r))
-> (Concrete (XParams r), StateAdam (XParams r))
go [MnistDataBatchR r]
rest (forall (y :: TK).
ArgsAdam
-> StateAdam y
-> SingletonTK y
-> Concrete y
-> Concrete (ADTensorKind y)
-> (Concrete y, StateAdam y)
updateWithGradientAdam
@(XParams r)
ArgsAdam
defaultArgsAdam StateAdam (XParams r)
stateAdam SingletonTK
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
SingletonTK (XParams r)
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete (XParams r)
parameters
Concrete (ADTensorKind (XParams r))
Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
gradient)
runBatch :: (Concrete (XParams r), StateAdam (XParams r))
-> (Int, [MnistDataR r])
-> IO (Concrete (XParams r), StateAdam (XParams r))
runBatch (!Concrete (XParams r)
parameters, !StateAdam (XParams r)
stateAdam) (Int
k, [MnistDataR r]
chunk) = do
let chunkR :: [MnistDataBatchR r]
chunkR = ([MnistDataR r] -> MnistDataBatchR r)
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR
([[MnistDataR r]] -> [MnistDataBatchR r])
-> [[MnistDataR r]] -> [MnistDataBatchR r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataR r] -> Bool) -> [[MnistDataR r]] -> [[MnistDataR r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataR r]
ch -> [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSize)
([[MnistDataR r]] -> [[MnistDataR r]])
-> [[MnistDataR r]] -> [[MnistDataR r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSize [MnistDataR r]
chunk
res :: (Concrete (XParams r), StateAdam (XParams r))
res@(Concrete (XParams r)
parameters2, StateAdam (XParams r)
_) = [MnistDataBatchR r]
-> (Concrete (XParams r), StateAdam (XParams r))
-> (Concrete (XParams r), StateAdam (XParams r))
go [MnistDataBatchR r]
chunkR (Concrete (XParams r)
parameters, StateAdam (XParams r)
stateAdam)
trainScore :: r
trainScore =
Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest ([MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk) ([MnistDataR r] -> MnistDataBatchR r
forall r. Elt r => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR [MnistDataR r]
chunk) Concrete (XParams r)
parameters2
testScore :: r
testScore =
Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest (Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) MnistDataBatchR r
testDataR Concrete (XParams r)
parameters2
lenChunk :: Int
lenChunk = [MnistDataR r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataR r]
chunk
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
String
prefix Int
k Int
lenChunk
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
(Concrete (XParams r), StateAdam (XParams r))
res
let runEpoch :: Int
-> (Concrete (XParams r), StateAdam (XParams r))
-> IO (Concrete (XParams r))
runEpoch Int
n (Concrete (XParams r)
params2, StateAdam (XParams r)
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
params2
runEpoch Int
n paramsStateAdam :: (Concrete (XParams r), StateAdam (XParams r))
paramsStateAdam@(!Concrete (XParams r)
_, !StateAdam (XParams r)
_) = do
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
let trainDataShuffled :: [MnistDataR r]
trainDataShuffled = StdGen -> [MnistDataR r] -> [MnistDataR r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataR r]
trainData
chunks :: [(Int, [MnistDataR r])]
chunks = Int -> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
([(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])])
-> [(Int, [MnistDataR r])] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..]
([[MnistDataR r]] -> [(Int, [MnistDataR r])])
-> [[MnistDataR r]] -> [(Int, [MnistDataR r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataR r] -> [[MnistDataR r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataR r]
trainDataShuffled
res <- ((Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> (Int, [MnistDataR r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))))
-> (Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> [(Int, [MnistDataR r])]
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
-> (Int, [MnistDataR r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
(Concrete (XParams r), StateAdam (XParams r))
-> (Int, [MnistDataR r])
-> IO (Concrete (XParams r), StateAdam (XParams r))
runBatch (Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r))))))
(Concrete (XParams r), StateAdam (XParams r))
paramsStateAdam [(Int, [MnistDataR r])]
chunks
runEpoch (succ n) res
ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams r)) Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
NoShape
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat 1 ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Nat
n
((':) @Nat n ((':) @Nat (n + 1) ((':) @Nat (n + 1) ('[] @Nat)))))
(TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Nat n ((':) @Nat ((n * 7) * 7) ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat n ('[] @Nat)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Nat SizeMnistLabel ((':) @Nat n ('[] @Nat))) (TKScalar r))
(TKS2 ((':) @Nat SizeMnistLabel ('[] @Nat)) (TKScalar r))))))
targetInit
res <- runEpoch 1 (targetInit, initialStateAdam ftk)
let testErrorFinal =
r
1 r -> r -> r
forall a. Num a => a -> a -> a
- Int -> MnistDataBatchR r -> Concrete (XParams r) -> r
ftest (Int
totalBatchSize Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
maxBatches) MnistDataBatchR r
testDataR Concrete
(TKProduct
(TKProduct
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 4 (TKScalar r)) (TKR2 1 (TKScalar r))))
(TKProduct
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))
(TKProduct (TKR2 2 (TKScalar r)) (TKR2 1 (TKScalar r)))))
Concrete (XParams r)
res
assertEqualUpToEpsilon 1e-1 expected testErrorFinal
{-# SPECIALIZE mnistTestCaseCNNRO
:: String
-> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Double
-> TestTree #-}
tensorADValMnistTestsCNNRO :: TestTree
tensorADValMnistTestsCNNRO :: TestTree
tensorADValMnistTestsCNNRO = String -> [TestTree] -> TestTree
testGroup String
"CNNR Once MNIST tests"
[ String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Double
-> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r,
(ADTensorScalar r :: Type) ~ (r :: Type)) =>
String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNRO String
"CNNRO 1 epoch, 1 batch" Int
1 Int
1 Int
4 Int
4 Int
8 Int
16 Int
1 Int
1
(Double
1 :: Double)
, String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Float
-> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r,
(ADTensorScalar r :: Type) ~ (r :: Type)) =>
String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNRO String
"CNNRO artificial 1 2 3 4 5" Int
1 Int
1 Int
2 Int
3 Int
4 Int
5 Int
1 Int
10
(Float
1 :: Float)
, String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Double
-> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r,
(ADTensorScalar r :: Type) ~ (r :: Type)) =>
String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNRO String
"CNNRO artificial 5 4 3 2 1" Int
5 Int
4 Int
3 Int
2 Int
1 Int
1 Int
1 Int
1
(Double
1 :: Double)
, String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Float
-> TestTree
forall r.
(Differentiable r, GoodScalar r, PrintfArg r,
AssertEqualUpToEpsilon r,
(ADTensorScalar r :: Type) ~ (r :: Type)) =>
String
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNRO String
"CNNRO 1 epoch, 0 batch" Int
1 Int
0 Int
4 Int
4 Int
16 Int
64 Int
16 Int
50
(Float
1.0 :: Float)
]