{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module TestMnistCNNS
( testTrees
) where
import Prelude
import Control.Monad (foldM, unless)
import GHC.TypeLits (KnownNat, type (<=))
import System.IO (hPutStrLn, stderr)
import System.Random
import Test.Tasty
import Test.Tasty.HUnit hiding (assert)
import Text.Printf
import Data.Array.Nested.Shaped.Shape
import HordeAd
import HordeAd.Core.Adaptor
import HordeAd.Core.AstEnv
import HordeAd.Core.AstFreshId
import HordeAd.Core.AstInterpret
import EqEpsilon
import MnistCnnShaped2 qualified
import MnistData
-- | The test trees exported by this module, in the order they are run.
--
-- The list is kept as a single signature plus a plain list literal:
-- a repeated signature or a type name placed in front of an element
-- does not compile.
testTrees :: [TestTree]
testTrees =
  [ tensorADValMnistTestsCNNSA
  , tensorADValMnistTestsCNNSI
  , tensorADValMnistTestsCNNSO
  ]
-- | Shorthand for 'X' applied to the shaped MNIST CNN parameter type
-- from "MnistCnnShaped2", with the target fixed to 'Concrete' and the
-- @h@/@w@ indices fixed to 'SizeMnistHeight' and 'SizeMnistWidth'.
--
-- The remaining indices (@kh@, @kw@, @c_out@, @n_hidden@) and the scalar
-- type @r@ are left to the caller.
type XParams kh kw c_out n_hidden r =
  X (MnistCnnShaped2.ADCnnMnistParametersShaped
       Concrete SizeMnistHeight SizeMnistWidth kh kw c_out n_hidden r)
mnistTestCaseCNNSA
:: forall kh kw r.
( 1 <= kh, 1 <= kw
, Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r )
=> String
-> Int -> Int -> SNat kh -> SNat kw -> Int -> Int -> Int -> Int -> r
-> TestTree
mnistTestCaseCNNSA :: forall (kh :: Natural) (kw :: Natural) r.
((<=) @Natural 1 kh, (<=) @Natural 1 kw, Differentiable r,
GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> SNat kh
-> SNat kw
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNSA String
prefix Int
epochs Int
maxBatches kh :: SNat kh
kh@SNat kh
SNat kw :: SNat kw
kw@SNat kw
SNat Int
c_outInt Int
n_hiddenInt
Int
miniBatchSizeInt Int
totalBatchSize r
expected =
Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
c_outInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_c_outSNat :: SNat c_out) ->
Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
n_hiddenInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_n_hiddenSNat :: SNat n_hidden) ->
Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
miniBatchSizeInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
miniBatchSize :: SNat miniBatchSize) ->
let targetInit :: Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit =
(Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
forall a b. (a, b) -> a
fst ((Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> (Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue
@(Concrete (X (MnistCnnShaped2.ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistWidth
kh kw c_out n_hidden r)))
Double
0.4 (Int -> StdGen
mkStdGen Int
44)
name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
, Int -> String
forall a. Show a => a -> String
show (SNat kh -> Int
forall (n :: Natural). SNat n -> Int
sNatValue SNat kh
kh), Int -> String
forall a. Show a => a -> String
show (SNat kw -> Int
forall (n :: Natural). SNat n -> Int
sNatValue SNat kw
kw)
, Int -> String
forall a. Show a => a -> String
show Int
c_outInt, Int -> String
forall a. Show a => a -> String
show Int
n_hiddenInt
, Int -> String
forall a. Show a => a -> String
show Int
miniBatchSizeInt
, Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK (SingletonTK
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> Int)
-> SingletonTK
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams kh kw c_out n_hidden r)
, Int -> String
forall a. Show a => a -> String
show (SingletonTK
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit) ]
ftest :: KnownNat batch_size
=> MnistDataBatchS batch_size r
-> Concrete (XParams kh kw c_out n_hidden r) -> r
ftest :: forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest @batch_size MnistDataBatchS batch_size r
mnistData Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
pars =
SNat kh
-> SNat kw
-> SNat n
-> SNat n
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r
-> r
forall (kh :: Natural) (kw :: Natural) (h :: Natural)
(w :: Natural) (c_out :: Natural) (n_hidden :: Natural)
(batch_size :: Natural) (target :: TK -> Type) r.
((h :: Natural) ~ (SizeMnistHeight :: Natural),
(w :: Natural) ~ (SizeMnistHeight :: Natural), (<=) @Natural 1 kh,
(<=) @Natural 1 kw,
(target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
GoodScalar r, Differentiable r) =>
SNat kh
-> SNat kw
-> SNat c_out
-> SNat n_hidden
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
-> r
MnistCnnShaped2.convMnistTestS SNat kh
kh SNat kw
kw (forall (n :: Natural). KnownNat n => SNat n
SNat @c_out) (forall (n :: Natural). KnownNat n => SNat n
SNat @n_hidden)
(forall (n :: Natural). KnownNat n => SNat n
SNat @batch_size) MnistDataBatchS batch_size r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete
(X ((Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r)),
Concrete
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
pars)
in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
String
prefix Int
epochs Int
maxBatches
trainData <- (MnistData r -> MnistDataS r) -> [MnistData r] -> [MnistDataS r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataS r
forall r. PrimElt r => MnistData r -> MnistDataS r
mkMnistDataS
([MnistData r] -> [MnistDataS r])
-> IO [MnistData r] -> IO [MnistDataS r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
testData <- map mkMnistDataS . take (totalBatchSize * maxBatches)
<$> loadMnistData testGlyphsPath testLabelsPath
withSNat (totalBatchSize * maxBatches) $ \(SNat @lenTestData) -> do
let testDataS :: MnistDataBatchS n r
testDataS = [MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
testData
f :: MnistDataBatchS miniBatchSize r
-> ADVal Concrete (XParams kh kw c_out n_hidden r)
-> ADVal Concrete (TKScalar r)
f :: MnistDataBatchS n r
-> ADVal
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> ADVal Concrete (TKScalar r)
f (Shaped
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r
glyphR, Shaped
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
labelR) ADVal
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
adinputs =
SNat kh
-> SNat kw
-> SNat n
-> SNat n
-> SNat n
-> (PrimalOf
(ADVal Concrete)
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r),
PrimalOf
(ADVal Concrete)
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r))
-> ADCnnMnistParametersShaped
(ADVal Concrete) SizeMnistHeight SizeMnistHeight kh kw n n r
-> ADVal Concrete (TKScalar r)
forall (kh :: Natural) (kw :: Natural) (h :: Natural)
(w :: Natural) (c_out :: Natural) (n_hidden :: Natural)
(batch_size :: Natural) (target :: TK -> Type) r.
((h :: Natural) ~ (SizeMnistHeight :: Natural),
(w :: Natural) ~ (SizeMnistHeight :: Natural), (<=) @Natural 1 kh,
(<=) @Natural 1 kw, ADReady target, ADReady (PrimalOf target),
GoodScalar r, Differentiable r) =>
SNat kh
-> SNat kw
-> SNat c_out
-> SNat n_hidden
-> SNat batch_size
-> (PrimalOf
target
(TKS
((':)
@Natural
batch_size
((':) @Natural h ((':) @Natural w ('[] @Natural))))
r),
PrimalOf
target
(TKS
((':)
@Natural batch_size ((':) @Natural SizeMnistLabel ('[] @Natural)))
r))
-> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
-> target (TKScalar r)
MnistCnnShaped2.convMnistLossFusedS
SNat kh
kh SNat kw
kw (forall (n :: Natural). KnownNat n => SNat n
SNat @c_out) (forall (n :: Natural). KnownNat n => SNat n
SNat @n_hidden)
SNat n
miniBatchSize (Shaped
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r
-> Concrete
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r)
forall r (target :: TK -> Type) (sh :: [Natural]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r
glyphR, Shaped
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
-> Concrete
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)
forall r (target :: TK -> Type) (sh :: [Natural]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
labelR)
(ADVal
Concrete
(X ((ADVal
Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
ADVal
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(ADVal
Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
ADVal
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(ADVal
Concrete
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r)),
ADVal
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(ADVal
Concrete
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r)),
ADVal
Concrete
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> ((ADVal
Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
ADVal
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(ADVal
Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
ADVal
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(ADVal
Concrete
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r)),
ADVal
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(ADVal
Concrete
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r)),
ADVal
Concrete
(TKS2 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))
forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget ADVal
Concrete
(X ((ADVal
Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
ADVal
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(ADVal
Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
ADVal
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(ADVal
Concrete
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r)),
ADVal
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(ADVal
Concrete
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r)),
ADVal
Concrete
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
ADVal
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
adinputs)
runBatch :: (Concrete (XParams kh kw c_out n_hidden r), StateAdam (XParams kh kw c_out n_hidden r))
-> (Int, [MnistDataS r])
-> IO (Concrete (XParams kh kw c_out n_hidden r), StateAdam (XParams kh kw c_out n_hidden r))
runBatch :: (Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> (Int, [MnistDataS r])
-> IO
(Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
runBatch (!Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters, !StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
stateAdam) (Int
k, [MnistDataS r]
chunk) = do
let chunkS :: [MnistDataBatchS n r]
chunkS = ([MnistDataS r] -> MnistDataBatchS n r)
-> [[MnistDataS r]] -> [MnistDataBatchS n r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS
([[MnistDataS r]] -> [MnistDataBatchS n r])
-> [[MnistDataS r]] -> [MnistDataBatchS n r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataS r] -> Bool) -> [[MnistDataS r]] -> [[MnistDataS r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataS r]
ch -> [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSizeInt)
([[MnistDataS r]] -> [[MnistDataS r]])
-> [[MnistDataS r]] -> [[MnistDataS r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSizeInt [MnistDataS r]
chunk
res :: (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
res@(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
parameters2, StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
_) =
(MnistDataBatchS n r
-> ADVal
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> ADVal Concrete (TKScalar r))
-> [MnistDataBatchS n r]
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall a (x :: TK) (z :: TK).
KnownSTK x =>
(a -> ADVal Concrete x -> ADVal Concrete z)
-> [a] -> Concrete x -> StateAdam x -> (Concrete x, StateAdam x)
sgdAdam MnistDataBatchS n r
-> ADVal
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> ADVal Concrete (TKScalar r)
MnistDataBatchS n r
-> ADVal
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> ADVal Concrete (TKScalar r)
f [MnistDataBatchS n r]
chunkS Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
stateAdam
trainScore :: r
trainScore = Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat ([MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk) ((forall (n :: Natural). KnownNat n => SNat n -> r) -> r)
-> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
forall a b. (a -> b) -> a -> b
$ \(SNat @len) ->
forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest @len ([MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
chunk) Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters2
testScore :: r
testScore = forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest @lenTestData MnistDataBatchS n r
testDataS Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters2
lenChunk :: Int
lenChunk = [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
String
prefix Int
k Int
lenChunk
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
res
let runEpoch :: Int
-> (Concrete (XParams kh kw c_out n_hidden r), StateAdam (XParams kh kw c_out n_hidden r))
-> IO (Concrete (XParams kh kw c_out n_hidden r))
runEpoch :: Int
-> (Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> IO
(Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
runEpoch Int
n (Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
params2, StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
params2
runEpoch Int
n paramsStateAdam :: (Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
paramsStateAdam@(!Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
_, !StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
_) = do
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
let trainDataShuffled :: [MnistDataS r]
trainDataShuffled = StdGen -> [MnistDataS r] -> [MnistDataS r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataS r]
trainData
chunks :: [(Int, [MnistDataS r])]
chunks = Int -> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
([(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])])
-> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..]
([[MnistDataS r]] -> [(Int, [MnistDataS r])])
-> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataS r]
trainDataShuffled
res <- ((Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> (Int, [MnistDataS r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))))
-> (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> [(Int, [MnistDataS r])]
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> (Int, [MnistDataS r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
(Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> (Int, [MnistDataS r])
-> IO
(Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
runBatch (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
(Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
paramsStateAdam [(Int, [MnistDataS r])]
chunks
runEpoch (succ n) res
ftk :: FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams kh kw c_out n_hidden r)) Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit
res <- Int
-> (Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> IO
(Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
runEpoch Int
1 (Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit, FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
forall (y :: TK). FullShapeTK y -> StateAdam y
initialStateAdam FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
ftk)
let testErrorFinal =
r
1 r -> r -> r
forall a. Num a => a -> a -> a
- MnistDataBatchS n r
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest MnistDataBatchS n r
testDataS Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
res
testErrorFinal @?~ expected
-- | Test group collecting the 'mnistTestCaseCNNSA' cases.
--
-- Every case passes its arguments positionally, in the order of
-- 'mnistTestCaseCNNSA''s parameters:
--
-- > prefix epochs maxBatches kh kw c_out n_hidden miniBatchSize totalBatchSize expected
--
-- where @kh@ and @kw@ are type-level naturals given as 'SNat' singletons
-- and @expected@ is the value the final test error (one minus the test
-- score) is compared against with '@?~'.  The type ascription on
-- @expected@ also fixes the scalar type @r@ used by the whole case.
tensorADValMnistTestsCNNSA :: TestTree
tensorADValMnistTestsCNNSA = testGroup "CNNS ADVal MNIST tests"
  [ mnistTestCaseCNNSA "CNNSA 1 epoch, 1 batch"
                       1 1 (SNat @4) (SNat @4) 8 16 1 1
                       (1 :: Double)
  , mnistTestCaseCNNSA "CNNSA artificial 1 2 3 4 5"
                       1 1 (SNat @2) (SNat @3) 4 5 1 10
                       (1 :: Float)
  , mnistTestCaseCNNSA "CNNSA artificial 5 4 3 2 1"
                       5 4 (SNat @3) (SNat @2) 1 1 1 1
                       (1 :: Double)
    -- maxBatches is 0 here, so the chunk list built with
    -- @take maxBatches@ is empty for this case.
  , mnistTestCaseCNNSA "CNNSA 1 epoch, 0 batch"
                       1 0 (SNat @4) (SNat @4) 16 64 16 50
                       (1.0 :: Float)
  ]
mnistTestCaseCNNSI
:: forall kh kw r.
( 1 <= kh, 1 <= kw
, Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r )
=> String
-> Int -> Int -> SNat kh -> SNat kw -> Int -> Int -> Int -> Int -> r
-> TestTree
mnistTestCaseCNNSI :: forall (kh :: Natural) (kw :: Natural) r.
((<=) @Natural 1 kh, (<=) @Natural 1 kw, Differentiable r,
GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> SNat kh
-> SNat kw
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNSI String
prefix Int
epochs Int
maxBatches kh :: SNat kh
kh@SNat kh
SNat kw :: SNat kw
kw@SNat kw
SNat Int
c_outInt Int
n_hiddenInt
Int
miniBatchSizeInt Int
totalBatchSize r
expected =
Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
c_outInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_c_outSNat :: SNat c_out) ->
Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
n_hiddenInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_n_hiddenSNat :: SNat n_hidden) ->
Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
miniBatchSizeInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
miniBatchSize :: SNat miniBatchSize) ->
let targetInit :: Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit =
(Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
forall a b. (a, b) -> a
fst ((Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> (Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue
@(Concrete (X (MnistCnnShaped2.ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistWidth
kh kw c_out n_hidden r)))
Double
0.4 (Int -> StdGen
mkStdGen Int
44)
name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
, Int -> String
forall a. Show a => a -> String
show (SNat kh -> Int
forall (n :: Natural). SNat n -> Int
sNatValue SNat kh
kh), Int -> String
forall a. Show a => a -> String
show (SNat kw -> Int
forall (n :: Natural). SNat n -> Int
sNatValue SNat kw
kw)
, Int -> String
forall a. Show a => a -> String
show Int
c_outInt, Int -> String
forall a. Show a => a -> String
show Int
n_hiddenInt
, Int -> String
forall a. Show a => a -> String
show Int
miniBatchSizeInt
, Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK (SingletonTK
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> Int)
-> SingletonTK
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams kh kw c_out n_hidden r)
, Int -> String
forall a. Show a => a -> String
show (SingletonTK
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit) ]
ftest :: KnownNat batch_size
=> MnistDataBatchS batch_size r
-> Concrete (XParams kh kw c_out n_hidden r) -> r
ftest :: forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest @batch_size MnistDataBatchS batch_size r
mnistData Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
pars =
SNat kh
-> SNat kw
-> SNat n
-> SNat n
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r
-> r
forall (kh :: Natural) (kw :: Natural) (h :: Natural)
(w :: Natural) (c_out :: Natural) (n_hidden :: Natural)
(batch_size :: Natural) (target :: TK -> Type) r.
((h :: Natural) ~ (SizeMnistHeight :: Natural),
(w :: Natural) ~ (SizeMnistHeight :: Natural), (<=) @Natural 1 kh,
(<=) @Natural 1 kw,
(target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
GoodScalar r, Differentiable r) =>
SNat kh
-> SNat kw
-> SNat c_out
-> SNat n_hidden
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
-> r
MnistCnnShaped2.convMnistTestS SNat kh
kh SNat kw
kw (forall (n :: Natural). KnownNat n => SNat n
SNat @c_out) (forall (n :: Natural). KnownNat n => SNat n
SNat @n_hidden)
(forall (n :: Natural). KnownNat n => SNat n
SNat @batch_size) MnistDataBatchS batch_size r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete
(X ((Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r)),
Concrete
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
pars)
in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
String
prefix Int
epochs Int
maxBatches
trainData <- (MnistData r -> MnistDataS r) -> [MnistData r] -> [MnistDataS r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataS r
forall r. PrimElt r => MnistData r -> MnistDataS r
mkMnistDataS
([MnistData r] -> [MnistDataS r])
-> IO [MnistData r] -> IO [MnistDataS r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
testData <- map mkMnistDataS . take (totalBatchSize * maxBatches)
<$> loadMnistData testGlyphsPath testLabelsPath
withSNat (totalBatchSize * maxBatches) $ \(SNat @lenTestData) -> do
let testDataS :: MnistDataBatchS n r
testDataS = [MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
testData
ftk :: FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
ftk = forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> FullShapeTK y
tftk @Concrete (forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams kh kw c_out n_hidden r)) Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit
(_, _, var, varAst2) <- FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> IO
(AstVarName
PrimalSpan
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
AstTensor
AstMethodShare
PrimalSpan
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
AstVarName
FullSpan
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
AstTensor
AstMethodLet
FullSpan
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall (x :: TK).
FullShapeTK x
-> IO
(AstVarName PrimalSpan x, AstTensor AstMethodShare PrimalSpan x,
AstVarName FullSpan x, AstTensor AstMethodLet FullSpan x)
funToAstRevIO FullShapeTK
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
ftk
(varGlyph, astGlyph) <-
funToAstIO (FTKS (miniBatchSize
:$$ sizeMnistHeight
:$$ sizeMnistWidth
:$$ ZSS) FTKScalar) id
(varLabel, astLabel) <-
funToAstIO (FTKS (miniBatchSize
:$$ sizeMnistLabel
:$$ ZSS) FTKScalar) id
let ast :: AstTensor AstMethodLet FullSpan (TKScalar r)
ast = AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (z :: TK) (s :: AstSpanType).
AstSpan s =>
AstTensor AstMethodLet s z -> AstTensor AstMethodLet s z
simplifyInline
(AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall a b. (a -> b) -> a -> b
$ SNat kh
-> SNat kw
-> SNat n
-> SNat n
-> SNat n
-> (PrimalOf
(AstTensor AstMethodLet FullSpan)
(TKS2
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
(TKScalar r)),
PrimalOf
(AstTensor AstMethodLet FullSpan)
(TKS2
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
(TKScalar r)))
-> ADCnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan)
SizeMnistHeight
SizeMnistHeight
kh
kw
n
n
r
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (kh :: Natural) (kw :: Natural) (h :: Natural)
(w :: Natural) (c_out :: Natural) (n_hidden :: Natural)
(batch_size :: Natural) (target :: TK -> Type) r.
((h :: Natural) ~ (SizeMnistHeight :: Natural),
(w :: Natural) ~ (SizeMnistHeight :: Natural), (<=) @Natural 1 kh,
(<=) @Natural 1 kw, ADReady target, ADReady (PrimalOf target),
GoodScalar r, Differentiable r) =>
SNat kh
-> SNat kw
-> SNat c_out
-> SNat n_hidden
-> SNat batch_size
-> (PrimalOf
target
(TKS
((':)
@Natural
batch_size
((':) @Natural h ((':) @Natural w ('[] @Natural))))
r),
PrimalOf
target
(TKS
((':)
@Natural batch_size ((':) @Natural SizeMnistLabel ('[] @Natural)))
r))
-> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
-> target (TKScalar r)
MnistCnnShaped2.convMnistLossFusedS
SNat kh
kh SNat kw
kw (forall (n :: Natural). KnownNat n => SNat n
SNat @c_out) (forall (n :: Natural). KnownNat n => SNat n
SNat @n_hidden)
SNat n
miniBatchSize (AstTensor
AstMethodLet
PrimalSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
(TKScalar r))
PrimalOf
(AstTensor AstMethodLet FullSpan)
(TKS2
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
(TKScalar r))
astGlyph, AstTensor
AstMethodLet
PrimalSpan
(TKS2
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
(TKScalar r))
PrimalOf
(AstTensor AstMethodLet FullSpan)
(TKS2
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
(TKScalar r))
astLabel)
(AstTensor
AstMethodLet
FullSpan
(X ((AstTensor
AstMethodLet
FullSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> ((AstTensor
AstMethodLet
FullSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))
forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget AstTensor
AstMethodLet
FullSpan
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
AstTensor
AstMethodLet
FullSpan
(X ((AstTensor
AstMethodLet
FullSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
varAst2)
f :: MnistDataBatchS miniBatchSize r
-> ADVal Concrete (XParams kh kw c_out n_hidden r)
-> ADVal Concrete (TKScalar r)
f (Shaped
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r
glyph, Shaped
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
label) ADVal
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
varInputs =
let env :: AstEnv (ADVal Concrete)
env = AstVarName
FullSpan
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> ADVal
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
FullSpan
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
var ADVal
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
ADVal
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
varInputs AstEnv (ADVal Concrete)
forall (target :: TK -> Type). AstEnv target
emptyEnv
envMnist :: AstEnv (ADVal Concrete)
envMnist = AstVarName
PrimalSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
(TKScalar r))
-> ADVal
Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
(TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
PrimalSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
(TKScalar r))
varGlyph (Shaped
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r
-> ADVal
Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
(TKScalar r))
forall r (target :: TK -> Type) (sh :: [Natural]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r
glyph)
(AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete))
-> AstEnv (ADVal Concrete) -> AstEnv (ADVal Concrete)
forall a b. (a -> b) -> a -> b
$ AstVarName
PrimalSpan
(TKS2
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
(TKScalar r))
-> ADVal
Concrete
(TKS2
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
(TKScalar r))
-> AstEnv (ADVal Concrete)
-> AstEnv (ADVal Concrete)
forall (target :: TK -> Type) (s :: AstSpanType) (y :: TK).
AstVarName s y -> target y -> AstEnv target -> AstEnv target
extendEnv AstVarName
PrimalSpan
(TKS2
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
(TKScalar r))
varLabel (Shaped
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
-> ADVal
Concrete
(TKS2
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
(TKScalar r))
forall r (target :: TK -> Type) (sh :: [Natural]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
label) AstEnv (ADVal Concrete)
env
in AstEnv (ADVal Concrete)
-> AstTensor AstMethodLet FullSpan (TKScalar r)
-> ADVal Concrete (TKScalar r)
forall (target :: TK -> Type) (y :: TK).
ADReady target =>
AstEnv target -> AstTensor AstMethodLet FullSpan y -> target y
interpretAstFull AstEnv (ADVal Concrete)
envMnist AstTensor AstMethodLet FullSpan (TKScalar r)
ast
runBatch :: (Concrete (XParams kh kw c_out n_hidden r), StateAdam (XParams kh kw c_out n_hidden r))
-> (Int, [MnistDataS r])
-> IO (Concrete (XParams kh kw c_out n_hidden r), StateAdam (XParams kh kw c_out n_hidden r))
runBatch (!Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters, !StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
stateAdam) (Int
k, [MnistDataS r]
chunk) = do
let chunkS :: [MnistDataBatchS n r]
chunkS = ([MnistDataS r] -> MnistDataBatchS n r)
-> [[MnistDataS r]] -> [MnistDataBatchS n r]
forall a b. (a -> b) -> [a] -> [b]
map [MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS
([[MnistDataS r]] -> [MnistDataBatchS n r])
-> [[MnistDataS r]] -> [MnistDataBatchS n r]
forall a b. (a -> b) -> a -> b
$ ([MnistDataS r] -> Bool) -> [[MnistDataS r]] -> [[MnistDataS r]]
forall a. (a -> Bool) -> [a] -> [a]
filter (\[MnistDataS r]
ch -> [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
ch Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
miniBatchSizeInt)
([[MnistDataS r]] -> [[MnistDataS r]])
-> [[MnistDataS r]] -> [[MnistDataS r]]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSizeInt [MnistDataS r]
chunk
res :: (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
res@(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
parameters2, StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
_) =
(MnistDataBatchS n r
-> ADVal
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> ADVal Concrete (TKScalar r))
-> [MnistDataBatchS n r]
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall a (x :: TK) (z :: TK).
KnownSTK x =>
(a -> ADVal Concrete x -> ADVal Concrete z)
-> [a] -> Concrete x -> StateAdam x -> (Concrete x, StateAdam x)
sgdAdam MnistDataBatchS n r
-> ADVal
Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> ADVal Concrete (TKScalar r)
MnistDataBatchS n r
-> ADVal
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> ADVal Concrete (TKScalar r)
f [MnistDataBatchS n r]
chunkS Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
stateAdam
!trainScore :: r
trainScore = Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat ([MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk) ((forall (n :: Natural). KnownNat n => SNat n -> r) -> r)
-> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
forall a b. (a -> b) -> a -> b
$ \(SNat @len) ->
forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest @len ([MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
chunk) Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters2
!testScore :: r
testScore = forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest @lenTestData MnistDataBatchS n r
testDataS Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
parameters2
!lenChunk :: Int
lenChunk = [MnistDataS r] -> Int
forall a. [a] -> Int
forall (t :: Type -> Type) a. Foldable t => t a -> Int
length [MnistDataS r]
chunk
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: (Batch %d with %d points)"
String
prefix Int
k Int
lenChunk
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Training error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
trainScore) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> r -> String
forall r. PrintfType r => String -> r
printf String
"%s: Validation error: %.2f%%"
String
prefix ((r
1 r -> r -> r
forall a. Num a => a -> a -> a
- r
testScore ) r -> r -> r
forall a. Num a => a -> a -> a
* r
100)
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
res
let runEpoch :: Int
-> (Concrete (XParams kh kw c_out n_hidden r), StateAdam (XParams kh kw c_out n_hidden r))
-> IO (Concrete (XParams kh kw c_out n_hidden r))
runEpoch Int
n (Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
params2, StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
_) | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
epochs = Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
params2
runEpoch Int
n paramsStateAdam :: (Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
paramsStateAdam@(!Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
_, !StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
_) = do
Bool -> Assertion -> Assertion
forall (f :: Type -> Type). Applicative f => Bool -> f () -> f ()
unless (Int
n_hiddenInt Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
10) (Assertion -> Assertion) -> Assertion -> Assertion
forall a b. (a -> b) -> a -> b
$
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$ String -> String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: [Epoch %d]" String
prefix Int
n
let trainDataShuffled :: [MnistDataS r]
trainDataShuffled = StdGen -> [MnistDataS r] -> [MnistDataS r]
forall a. StdGen -> [a] -> [a]
shuffle (Int -> StdGen
mkStdGen (Int -> StdGen) -> Int -> StdGen
forall a b. (a -> b) -> a -> b
$ Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
5) [MnistDataS r]
trainData
chunks :: [(Int, [MnistDataS r])]
chunks = Int -> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a. Int -> [a] -> [a]
take Int
maxBatches
([(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])])
-> [(Int, [MnistDataS r])] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..]
([[MnistDataS r]] -> [(Int, [MnistDataS r])])
-> [[MnistDataS r]] -> [(Int, [MnistDataS r])]
forall a b. (a -> b) -> a -> b
$ Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
totalBatchSize [MnistDataS r]
trainDataShuffled
res <- ((Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> (Int, [MnistDataS r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))))
-> (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> [(Int, [MnistDataS r])]
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
forall (t :: Type -> Type) (m :: Type -> Type) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
-> (Int, [MnistDataS r])
-> IO
(Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
(Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> (Int, [MnistDataS r])
-> IO
(Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
runBatch (Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))),
StateAdam
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))))
(Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StateAdam
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
paramsStateAdam [(Int, [MnistDataS r])]
chunks
runEpoch (succ n) res
res <- runEpoch 1 (targetInit, initialStateAdam ftk)
let testErrorFinal =
r
1 r -> r -> r
forall a. Num a => a -> a -> a
- MnistDataBatchS n r
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest MnistDataBatchS n r
testDataS Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
res
testErrorFinal @?~ expected
-- | Test group labelled "CNNS Intermediate MNIST tests".
--
-- It bundles four test cases, each built by 'mnistTestCaseCNNSI', with the
-- scalar type alternating between 'Double' and 'Float'.  Every call passes,
-- in order: a label, two 'Int's, two 'SNat' singletons, four more 'Int's and
-- a final scalar of the case's scalar type.
--
-- NOTE(review): 'mnistTestCaseCNNSI' is defined outside this view.  The
-- argument roles named in the comments below (epochs, maxBatches, kh, kw,
-- c_out, n_hidden, miniBatchSize, totalBatchSize, expected) are borrowed from
-- the identically-typed sibling 'mnistTestCaseCNNSO' -- confirm they match.
--
-- NOTE(review): all four cases pass 1 as the final scalar; confirm this is
-- the intended expected value for each configuration and not a placeholder.
--
-- The bare type lines interleaved with the expressions appear to be
-- tool-emitted type annotations; they are left untouched here.
tensorADValMnistTestsCNNSI :: TestTree
tensorADValMnistTestsCNNSI :: TestTree
tensorADValMnistTestsCNNSI = String -> [TestTree] -> TestTree
testGroup String
"CNNS Intermediate MNIST tests"
-- Case 1 (Double): Ints 1, 1; SNat @4, SNat @4; Ints 8, 16, 1, 1; scalar 1.
-- Presumably epochs = 1, maxBatches = 1, kh = kw = 4 -- TODO confirm.
[ String
-> Int
-> Int
-> SNat 4
-> SNat 4
-> Int
-> Int
-> Int
-> Int
-> Double
-> TestTree
forall (kh :: Natural) (kw :: Natural) r.
((<=) @Natural 1 kh, (<=) @Natural 1 kw, Differentiable r,
GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> SNat kh
-> SNat kw
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNSI String
"CNNSI 1 epoch, 1 batch"
Int
1 Int
1 (forall (n :: Natural). KnownNat n => SNat n
SNat @4) (forall (n :: Natural). KnownNat n => SNat n
SNat @4) Int
8 Int
16 Int
1 Int
1
(Double
1 :: Double)
-- Case 2 (Float): Ints 1, 1; SNat @2, SNat @3; Ints 4, 5, 1, 10; scalar 1.
-- The two singletons differ here (2 vs 3), unlike case 1.
, String
-> Int
-> Int
-> SNat 2
-> SNat 3
-> Int
-> Int
-> Int
-> Int
-> Float
-> TestTree
forall (kh :: Natural) (kw :: Natural) r.
((<=) @Natural 1 kh, (<=) @Natural 1 kw, Differentiable r,
GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> SNat kh
-> SNat kw
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNSI String
"CNNSI artificial 1 2 3 4 5"
Int
1 Int
1 (forall (n :: Natural). KnownNat n => SNat n
SNat @2) (forall (n :: Natural). KnownNat n => SNat n
SNat @3) Int
4 Int
5 Int
1 Int
10
(Float
1 :: Float)
-- Case 3 (Double): Ints 5, 4; SNat @3, SNat @2; Ints 1, 1, 1, 1; scalar 1.
-- The singletons are those of case 2 swapped (3 vs 2).
, String
-> Int
-> Int
-> SNat 3
-> SNat 2
-> Int
-> Int
-> Int
-> Int
-> Double
-> TestTree
forall (kh :: Natural) (kw :: Natural) r.
((<=) @Natural 1 kh, (<=) @Natural 1 kw, Differentiable r,
GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> SNat kh
-> SNat kw
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNSI String
"CNNSI artificial 5 4 3 2 1"
Int
5 Int
4 (forall (n :: Natural). KnownNat n => SNat n
SNat @3) (forall (n :: Natural). KnownNat n => SNat n
SNat @2) Int
1 Int
1 Int
1 Int
1
(Double
1 :: Double)
-- Case 4 (Float): Ints 1, 0; SNat @4, SNat @4; Ints 16, 64, 16, 50;
-- scalar 1.0.  The second Int is 0, matching the "0 batch" label.
-- NOTE(review): presumably this exercises the no-training path -- confirm
-- how 'mnistTestCaseCNNSI' behaves with a zero batch limit.
, String
-> Int
-> Int
-> SNat 4
-> SNat 4
-> Int
-> Int
-> Int
-> Int
-> Float
-> TestTree
forall (kh :: Natural) (kw :: Natural) r.
((<=) @Natural 1 kh, (<=) @Natural 1 kw, Differentiable r,
GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r) =>
String
-> Int
-> Int
-> SNat kh
-> SNat kw
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNSI String
"CNNSI 1 epoch, 0 batch"
Int
1 Int
0 (forall (n :: Natural). KnownNat n => SNat n
SNat @4) (forall (n :: Natural). KnownNat n => SNat n
SNat @4) Int
16 Int
64 Int
16 Int
50
(Float
1.0 :: Float)
]
mnistTestCaseCNNSO
:: forall kh kw r.
( 1 <= kh, 1 <= kw
, Differentiable r, GoodScalar r
, PrintfArg r, AssertEqualUpToEpsilon r, ADTensorScalar r ~ r )
=> String
-> Int -> Int -> SNat kh -> SNat kw -> Int -> Int -> Int -> Int -> r
-> TestTree
mnistTestCaseCNNSO :: forall (kh :: Natural) (kw :: Natural) r.
((<=) @Natural 1 kh, (<=) @Natural 1 kw, Differentiable r,
GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r,
(ADTensorScalar r :: Type) ~ (r :: Type)) =>
String
-> Int
-> Int
-> SNat kh
-> SNat kw
-> Int
-> Int
-> Int
-> Int
-> r
-> TestTree
mnistTestCaseCNNSO String
prefix Int
epochs Int
maxBatches kh :: SNat kh
kh@SNat kh
SNat kw :: SNat kw
kw@SNat kw
SNat Int
c_outInt Int
n_hiddenInt
Int
miniBatchSizeInt Int
totalBatchSize r
expected =
Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
c_outInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_c_outSNat :: SNat c_out) ->
Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
n_hiddenInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
_n_hiddenSNat :: SNat n_hidden) ->
Int
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall r.
Int -> (forall (n :: Natural). KnownNat n => SNat n -> r) -> r
withSNat Int
miniBatchSizeInt ((forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree)
-> (forall (n :: Natural). KnownNat n => SNat n -> TestTree)
-> TestTree
forall a b. (a -> b) -> a -> b
$ \(SNat n
miniBatchSize :: SNat miniBatchSize) ->
let targetInit :: Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit =
(Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
forall a b. (a, b) -> a
fst ((Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)))
-> (Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r)),
StdGen)
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
forall a b. (a -> b) -> a -> b
$ forall vals. RandomValue vals => Double -> StdGen -> (vals, StdGen)
randomValue
@(Concrete (X (MnistCnnShaped2.ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistWidth
kh kw c_out n_hidden r)))
Double
0.4 (Int -> StdGen
mkStdGen Int
44)
name :: String
name = String
prefix String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
": "
String -> String -> String
forall a. [a] -> [a] -> [a]
++ [String] -> String
unwords [ Int -> String
forall a. Show a => a -> String
show Int
epochs, Int -> String
forall a. Show a => a -> String
show Int
maxBatches
, Int -> String
forall a. Show a => a -> String
show (SNat kh -> Int
forall (n :: Natural). SNat n -> Int
sNatValue SNat kh
kh), Int -> String
forall a. Show a => a -> String
show (SNat kw -> Int
forall (n :: Natural). SNat n -> Int
sNatValue SNat kw
kw)
, Int -> String
forall a. Show a => a -> String
show Int
c_outInt, Int -> String
forall a. Show a => a -> String
show Int
n_hiddenInt
, Int -> String
forall a. Show a => a -> String
show Int
miniBatchSizeInt
, Int -> String
forall a. Show a => a -> String
show (Int -> String) -> Int -> String
forall a b. (a -> b) -> a -> b
$ SingletonTK
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> Int
forall (y :: TK). SingletonTK y -> Int
widthSTK
(SingletonTK
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> Int)
-> SingletonTK
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> Int
forall a b. (a -> b) -> a -> b
$ forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK @(XParams kh kw c_out n_hidden r)
, Int -> String
forall a. Show a => a -> String
show (SingletonTK
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> Int
forall (y :: TK). SingletonTK y -> Concrete y -> Int
forall (target :: TK -> Type) (y :: TK).
BaseTensor target =>
SingletonTK y -> target y -> Int
tsize SingletonTK
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
forall (y :: TK). KnownSTK y => SingletonTK y
knownSTK Concrete
(TKProduct
(TKProduct
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))))
(TKProduct
(TKProduct
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r))
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r)))
(TKProduct
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r))
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit) ]
ftest :: KnownNat batch_size
=> MnistDataBatchS batch_size r
-> Concrete (XParams kh kw c_out n_hidden r) -> r
ftest :: forall (batch_size :: Natural).
KnownNat batch_size =>
MnistDataBatchS batch_size r
-> Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
-> r
ftest @batch_size MnistDataBatchS batch_size r
mnistData Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
pars =
SNat kh
-> SNat kw
-> SNat n
-> SNat n
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r
-> r
forall (kh :: Natural) (kw :: Natural) (h :: Natural)
(w :: Natural) (c_out :: Natural) (n_hidden :: Natural)
(batch_size :: Natural) (target :: TK -> Type) r.
((h :: Natural) ~ (SizeMnistHeight :: Natural),
(w :: Natural) ~ (SizeMnistHeight :: Natural), (<=) @Natural 1 kh,
(<=) @Natural 1 kw,
(target :: (TK -> Type)) ~ (Concrete :: (TK -> Type)),
GoodScalar r, Differentiable r) =>
SNat kh
-> SNat kw
-> SNat c_out
-> SNat n_hidden
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
-> r
MnistCnnShaped2.convMnistTestS SNat kh
kh SNat kw
kw (forall (n :: Natural). KnownNat n => SNat n
SNat @c_out) (forall (n :: Natural). KnownNat n => SNat n
SNat @n_hidden)
(forall (n :: Natural). KnownNat n => SNat n
SNat @batch_size) MnistDataBatchS batch_size r
mnistData (forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget @Concrete Concrete
(X ((Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r)),
Concrete
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
pars)
in String -> Assertion -> TestTree
testCase String
name (Assertion -> TestTree) -> Assertion -> TestTree
forall a b. (a -> b) -> a -> b
$ do
Handle -> String -> Assertion
hPutStrLn Handle
stderr (String -> Assertion) -> String -> Assertion
forall a b. (a -> b) -> a -> b
$
String -> String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"\n%s: Epochs to run/max batches per epoch: %d/%d"
String
prefix Int
epochs Int
maxBatches
trainData <- (MnistData r -> MnistDataS r) -> [MnistData r] -> [MnistDataS r]
forall a b. (a -> b) -> [a] -> [b]
map MnistData r -> MnistDataS r
forall r. PrimElt r => MnistData r -> MnistDataS r
mkMnistDataS
([MnistData r] -> [MnistDataS r])
-> IO [MnistData r] -> IO [MnistDataS r]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> String -> IO [MnistData r]
forall r.
(Storable r, Fractional r) =>
String -> String -> IO [MnistData r]
loadMnistData String
trainGlyphsPath String
trainLabelsPath
testData <- map mkMnistDataS . take (totalBatchSize * maxBatches)
<$> loadMnistData testGlyphsPath testLabelsPath
withSNat (totalBatchSize * maxBatches) $ \(SNat @lenTestData) -> do
let testDataS :: MnistDataBatchS n r
testDataS = [MnistDataS r] -> MnistDataBatchS n r
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
testData
dataInit :: (Concrete
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r),
Concrete
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r))
dataInit = case Int -> [MnistDataS r] -> [[MnistDataS r]]
forall e. Int -> [e] -> [[e]]
chunksOf Int
miniBatchSizeInt [MnistDataS r]
testData of
[MnistDataS r]
d : [[MnistDataS r]]
_ -> let (Shaped
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r
dglyph, Shaped
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
dlabel) = [MnistDataS r]
-> (Shaped
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r,
Shaped
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)
forall (batch_size :: Natural) r.
(Elt r, KnownNat batch_size) =>
[MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS [MnistDataS r]
d
in (Shaped
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r
-> Concrete
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r)
forall r (target :: TK -> Type) (sh :: [Natural]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r
dglyph, Shaped
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
-> Concrete
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)
forall r (target :: TK -> Type) (sh :: [Natural]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r
dlabel)
[] -> String
-> (Concrete
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r),
Concrete
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r))
forall a. HasCallStack => String -> a
error String
"empty test data"
f :: ( MnistCnnShaped2.ADCnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan)
SizeMnistHeight SizeMnistWidth
kh kw c_out n_hidden r
, ( AstTensor AstMethodLet FullSpan (TKS '[miniBatchSize, SizeMnistHeight, SizeMnistWidth] r)
, AstTensor AstMethodLet FullSpan (TKS '[miniBatchSize, SizeMnistLabel] r) ) )
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f :: (ADCnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan)
SizeMnistHeight
SizeMnistHeight
kh
kw
n
n
r,
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r),
AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
r)))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f = \ (ADCnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan)
SizeMnistHeight
SizeMnistHeight
kh
kw
n
n
r
pars, (AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r)
glyphR, AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)
labelR)) ->
SNat kh
-> SNat kw
-> SNat n
-> SNat n
-> SNat n
-> (PrimalOf
(AstTensor AstMethodLet FullSpan)
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r),
PrimalOf
(AstTensor AstMethodLet FullSpan)
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r))
-> ADCnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan)
SizeMnistHeight
SizeMnistHeight
kh
kw
n
n
r
-> AstTensor AstMethodLet FullSpan (TKScalar r)
forall (kh :: Natural) (kw :: Natural) (h :: Natural)
(w :: Natural) (c_out :: Natural) (n_hidden :: Natural)
(batch_size :: Natural) (target :: TK -> Type) r.
((h :: Natural) ~ (SizeMnistHeight :: Natural),
(w :: Natural) ~ (SizeMnistHeight :: Natural), (<=) @Natural 1 kh,
(<=) @Natural 1 kw, ADReady target, ADReady (PrimalOf target),
GoodScalar r, Differentiable r) =>
SNat kh
-> SNat kw
-> SNat c_out
-> SNat n_hidden
-> SNat batch_size
-> (PrimalOf
target
(TKS
((':)
@Natural
batch_size
((':) @Natural h ((':) @Natural w ('[] @Natural))))
r),
PrimalOf
target
(TKS
((':)
@Natural batch_size ((':) @Natural SizeMnistLabel ('[] @Natural)))
r))
-> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
-> target (TKScalar r)
MnistCnnShaped2.convMnistLossFusedS
SNat kh
kh SNat kw
kw (forall (n :: Natural). KnownNat n => SNat n
SNat @c_out) (forall (n :: Natural). KnownNat n => SNat n
SNat @n_hidden)
SNat n
miniBatchSize (AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r)
-> PrimalOf
(AstTensor AstMethodLet FullSpan)
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r)
forall (target :: TK -> Type) (sh :: [Natural]) (x :: TK).
BaseTensor target =>
target (TKS2 sh x) -> PrimalOf target (TKS2 sh x)
sprimalPart AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r)
glyphR, AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)
-> PrimalOf
(AstTensor AstMethodLet FullSpan)
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)
forall (target :: TK -> Type) (sh :: [Natural]) (x :: TK).
BaseTensor target =>
target (TKS2 sh x) -> PrimalOf target (TKS2 sh x)
sprimalPart AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r)
labelR) ADCnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan)
SizeMnistHeight
SizeMnistHeight
kh
kw
n
n
r
pars
artRaw :: AstArtifactRev
(X (((AstTensor
AstMethodLet
FullSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))),
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r),
AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
r))))
(TKScalar r)
artRaw = ((((AstTensor
AstMethodLet
FullSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))),
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r),
AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
r)))
-> AstTensor AstMethodLet FullSpan (TKScalar r))
-> Value
(((AstTensor
AstMethodLet
FullSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))),
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r),
AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
r)))
-> AstArtifactRev
(X (((AstTensor
AstMethodLet
FullSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))),
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r),
AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
r))))
(TKScalar r)
forall src r tgt.
((X src :: TK) ~ (X (Value src) :: TK), KnownSTK (X src),
AdaptableTarget (AstTensor AstMethodLet FullSpan) src,
AdaptableTarget Concrete (Value src),
(tgt :: Type)
~ (AstTensor AstMethodLet FullSpan (TKScalar r) :: Type)) =>
(src -> tgt) -> Value src -> AstArtifactRev (X src) (TKScalar r)
gradArtifact (((AstTensor
AstMethodLet
FullSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r)),
AstTensor
AstMethodLet
FullSpan
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))),
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r),
AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
r)))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
(ADCnnMnistParametersShaped
(AstTensor AstMethodLet FullSpan)
SizeMnistHeight
SizeMnistHeight
kh
kw
n
n
r,
(AstTensor
AstMethodLet
FullSpan
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r),
AstTensor
AstMethodLet
FullSpan
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural)))
r)))
-> AstTensor AstMethodLet FullSpan (TKScalar r)
f (Concrete
(X ((Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r)),
Concrete
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
-> ((Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r)),
Concrete
(TKS2 ((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r))))
forall (target :: TK -> Type) vals.
AdaptableTarget target vals =>
target (X vals) -> vals
fromTarget Concrete
(X ((Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
1
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':)
@Natural
n
((':)
@Natural
n
((':) @Natural (kh + 1) ((':) @Natural (kw + 1) ('[] @Natural)))))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':) @Natural n ((':) @Natural ((n * 7) * 7) ('[] @Natural)))
(TKScalar r)),
Concrete (TKS2 ((':) @Natural n ('[] @Natural)) (TKScalar r))),
(Concrete
(TKS2
((':) @Natural SizeMnistLabel ((':) @Natural n ('[] @Natural)))
(TKScalar r)),
Concrete
(TKS2
((':) @Natural SizeMnistLabel ('[] @Natural)) (TKScalar r)))))
Concrete
(X (ADCnnMnistParametersShaped
Concrete SizeMnistHeight SizeMnistHeight kh kw n n r))
targetInit, (Concrete
(TKS
((':)
@Natural
n
((':)
@Natural
SizeMnistHeight
((':) @Natural SizeMnistHeight ('[] @Natural))))
r),
Concrete
(TKS
((':) @Natural n ((':) @Natural SizeMnistLabel ('[] @Natural))) r))
dataInit)
-- Reverse-derivative artifact that the training loop interprets.
-- It is obtained by passing the raw artifact 'artRaw' through
-- 'simplifyArtifactGradient', which maps an 'AstArtifactRev x z' to an
-- 'AstArtifactRev x z' (same input and output tensor kinds).
-- Here the input kind is the product of the network parameters and the
-- (glyph batch, label batch) pair, and the output kind is a scalar.
-- NOTE(review): no explicit signature is given; the type is inferred
-- from 'artRaw'.
art = simplifyArtifactGradient artRaw
-- Run the optimizer over a list of minibatches, threading the pair of
-- current parameters and Adam optimizer state through every step.
--
-- Arguments:
--   * the remaining minibatches, each a (glyph batch, label batch) pair;
--   * the current (parameters, Adam state) pair.
--
-- Result: the (parameters, Adam state) pair after all minibatches have
-- been consumed. An empty list returns the input pair unchanged.
go :: [MnistDataBatchS miniBatchSize r]
   -> ( Concrete (XParams kh kw c_out n_hidden r)
      , StateAdam (XParams kh kw c_out n_hidden r) )
   -> ( Concrete (XParams kh kw c_out n_hidden r)
      , StateAdam (XParams kh kw c_out n_hidden r) )
go [] (parameters, stateAdam) = (parameters, stateAdam)
-- The bang patterns force both components of the accumulator at every
-- step, so no chain of unevaluated updates builds up across minibatches.
go ((glyph, label) : rest) (!parameters, !stateAdam) =
  let -- The artifact takes a single argument: the parameters paired with
      -- the (glyph, label) minibatch, both lifted to concrete tensors.
      parametersAndInput =
        tpair parameters (tpair (sconcrete glyph) (sconcrete label))
      -- 'revInterpretArtifact' returns (derivative w.r.t. the whole
      -- input, objective value); 'fst' keeps the derivative and
      -- 'tproject1' keeps only its parameter component, dropping the
      -- part that corresponds to the glyph/label input.
      -- NOTE(review): 'Nothing' is passed for the optional third
      -- argument of kind 'ADTensorKind (TKScalar r)' -- presumably the
      -- incoming cotangent, with a default used when absent; confirm.
      gradient =
        tproject1 $ fst
        $ revInterpretArtifact art parametersAndInput Nothing
  in go rest (updateWithGradientAdam
                @(XParams kh kw c_out n_hidden r)
                defaultArgsAdam stateAdam knownSTK parameters gradient)
-- Process one chunk of training data and optionally log progress.
--
-- Arguments:
--   * the current parameters paired with the Adam optimiser state; the
--     bang patterns force both to WHNF before the chunk is processed;
--   * the batch number @k@ (used only in the log line) and the chunk.
--
-- Result: the pair returned by @go@ for this chunk, i.e. the updated
-- parameters and the updated Adam state.
runBatch :: ( Concrete (XParams kh kw c_out n_hidden r)
            , StateAdam (XParams kh kw c_out n_hidden r) )
         -> (Int, [MnistDataS r])
         -> IO ( Concrete (XParams kh kw c_out n_hidden r)
               , StateAdam (XParams kh kw c_out n_hidden r) )
runBatch (!parameters, !stateAdam) (k, chunk) = do
  -- The chunk is cut into groups of miniBatchSizeInt elements and only
  -- groups of exactly that length are kept (presumably because
  -- mkMnistDataBatchS yields a batch whose size is fixed in its type).
  let chunkS = map mkMnistDataBatchS
               $ filter (\ch -> length ch == miniBatchSizeInt)
               $ chunksOf miniBatchSizeInt chunk
      res@(parameters2, _) = go chunkS (parameters, stateAdam)
      -- Both scores are used only by the log lines below, so laziness
      -- means they are not computed when logging is switched off.
      trainScore = withSNat (length chunk) $ \(SNat @len) ->
        ftest @len (mkMnistDataBatchS chunk) parameters2
      testScore = ftest @lenTestData testDataS parameters2
      lenChunk = length chunk
  -- Logging is skipped when n_hiddenInt is below 10.
  unless (n_hiddenInt < 10) $ do
    hPutStrLn stderr $
      printf "\n%s: (Batch %d with %d points)"
             prefix k lenChunk
    hPutStrLn stderr $
      printf "%s: Training error: %.2f%%"
             prefix ((1 - trainScore) * 100)
    hPutStrLn stderr $
      printf "%s: Validation error: %.2f%%"
             prefix ((1 - testScore) * 100)
  return res
-- runEpoch n (params, adamState) runs epochs n, n+1, ... up to and
-- including @epochs@ and returns the final parameters; the Adam state is
-- dropped at the end.
let runEpoch :: Int
             -> ( Concrete (XParams kh kw c_out n_hidden r)
                , StateAdam (XParams kh kw c_out n_hidden r) )
             -> IO (Concrete (XParams kh kw c_out n_hidden r))
    -- Past the last epoch: stop and hand back the parameters.
    runEpoch n (params2, _) | n > epochs = return params2
    -- The bang patterns force both components before the epoch starts.
    runEpoch n paramsStateAdam@(!_, !_) = do
      -- Logging is skipped when n_hiddenInt is below 10.
      unless (n_hiddenInt < 10) $
        hPutStrLn stderr $ printf "\n%s: [Epoch %d]" prefix n
      -- The shuffle seed depends only on the epoch number, so every run
      -- sees the same order of training data.
      let trainDataShuffled = shuffle (mkStdGen $ n + 5) trainData
          -- Number the chunks from 1 and keep at most maxBatches of them.
          chunks = take maxBatches
                   $ zip [1 ..]
                   $ chunksOf totalBatchSize trainDataShuffled
      res <- foldM runBatch paramsStateAdam chunks
      runEpoch (succ n) res
    -- Full shape of the parameter tensors, taken from the initial
    -- parameters.
    ftk = tftk @Concrete (knownSTK @(XParams kh kw c_out n_hidden r))
               targetInit
-- Train from the initial parameters with an Adam state built from ftk,
-- starting at epoch 1.
res <- runEpoch 1 (targetInit, initialStateAdam ftk)
-- Final test error: one minus the score ftest reports on the test data.
let testErrorFinal =
      1 - ftest testDataS res
assertEqualUpToEpsilon 1e-1 expected testErrorFinal
-- | The \"once\" family of shaped-CNN MNIST tests: a single tasty group
-- holding four invocations of 'mnistTestCaseCNNSO', alternating between
-- 'Double' and 'Float' as the scalar type and varying the kernel sizes.
--
-- NOTE(review): the argument order presumably mirrors 'mnistTestCaseCNNSA'
-- (prefix, epochs, maxBatches, kh, kw, c_out, n_hidden, miniBatchSize,
-- totalBatchSize, expected) -- confirm against the definition of
-- 'mnistTestCaseCNNSO'.
tensorADValMnistTestsCNNSO :: TestTree
tensorADValMnistTestsCNNSO =
  testGroup "CNNS Once MNIST tests" onceCases
 where
  -- Each case instantiates the kernel-size type parameters through the
  -- explicit 'SNat' type applications, so the cases are written out in
  -- full instead of being generated from a homogeneous table.
  onceCases :: [TestTree]
  onceCases =
    [ mnistTestCaseCNNSO
        "CNNSO 1 epoch, 1 batch"
        1 1 (SNat @4) (SNat @4) 8 16 1 1
        (1 :: Double)
    , mnistTestCaseCNNSO
        "CNNSO artificial 1 2 3 4 5"
        1 1 (SNat @2) (SNat @3) 4 5 1 10
        (1 :: Float)
    , mnistTestCaseCNNSO
        "CNNSO artificial 5 4 3 2 1"
        5 4 (SNat @3) (SNat @2) 1 1 1 1
        (1 :: Double)
    , mnistTestCaseCNNSO
        "CNNSO 1 epoch, 0 batch"
        1 0 (SNat @4) (SNat @4) 16 64 16 50
        (1.0 :: Float)
    ]