{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -Wno-missing-export-lists #-}
module MnistCnnShaped2 where
import Prelude
import Data.Vector.Generic qualified as V
import Data.Vector.Storable (Vector)
import GHC.TypeLits (fromSNat, type (*), type (+), type (<=), type Div)
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Shaped.Shape
import HordeAd
import MnistData
-- | All trainable parameters of the network, as a nested tuple of shaped
-- tensors whose dimension widths are checked statically.
--
-- The four pairs are, in order (as destructured by 'convMnistTwoS'):
--
--   1. first convolution: kernel @[c_out, 1, kh + 1, kw + 1]@ and bias;
--   2. second convolution: kernel @[c_out, c_out, kh + 1, kw + 1]@ and bias;
--   3. dense hidden layer: weights and biases;
--   4. readout layer: weights and biases, one row per MNIST label.
--
-- The kernel spatial extents are written @kh + 1@ and @kw + 1@, so @kh@ and
-- @kw@ are one less than the kernel height and width.
type ADCnnMnistParametersShaped
       (target :: Target) h w kh kw c_out n_hidden r =
  ( ( target (TKS '[c_out, 1, kh + 1, kw + 1] r)
    , target (TKS '[c_out] r) )
  , ( target (TKS '[c_out, c_out, kh + 1, kw + 1] r)
    , target (TKS '[c_out] r) )
  , ( target (TKS '[n_hidden, c_out * (h `Div` 4) * (w `Div` 4) ] r)
    , target (TKS '[n_hidden] r) )
  , ( target (TKS '[SizeMnistLabel, n_hidden] r)
    , target (TKS '[SizeMnistLabel] r) )
  )
-- | A single convolutional layer, built from these steps:
--
--   * an unpadded 2D convolution of @input@ with @ker@;
--   * a per-output-channel @bias@ added at every batch and spatial
--     position;
--   * 'reluS';
--   * a 2x2 max-pool with stride 2, which halves both spatial dimensions.
--
-- The 'SNat' arguments are matched only to bring the corresponding
-- 'KnownNat' instances into scope for the tensor operations below.
convMnistLayerS
  :: forall kh kw h w c_in c_out batch_size target r.
     ( 1 <= kh
     , 1 <= kw
     , ADReady target, GoodScalar r, Differentiable r )
  => SNat kh -> SNat kw -> SNat h -> SNat w
  -> SNat c_in -> SNat c_out -> SNat batch_size
  -> target (TKS '[c_out, c_in, kh + 1, kw + 1] r)
  -> target (TKS '[batch_size, c_in, h, w] r)
  -> target (TKS '[c_out] r)
  -> target (TKS '[batch_size, c_out, h `Div` 2, w `Div` 2] r)
convMnistLayerS SNat SNat SNat SNat SNat SNat SNat ker input bias =
  let -- The unpadded convolution keeps the spatial size: [batch, c_out, h, w].
      yConv :: target (TKS '[batch_size, c_out, h, w] r)
      yConv = conv2dUnpaddedS ker input
      -- Replicate the bias to [batch_size, h, w, c_out], then move the
      -- channel axis to position 1 so the shape matches 'yConv'.
      -- The replication counts are inferred from the addition below.
      biasStretched = stranspose @'[0, 3, 1, 2]
                      $ sreplicate
                      $ sreplicate
                      $ sreplicate bias
      yRelu = reluS $ yConv + biasStretched
  in maxPool2dUnpaddedS @2 @2 yRelu
-- | The whole network, applied to a batch of single-channel images.
--
-- It runs two 'convMnistLayerS' layers, each of which halves the spatial
-- dimensions. The result is flattened per image and passed through a dense
-- layer with 'reluS', then a linear readout. The output holds one score per
-- MNIST label for every image, with shape @[SizeMnistLabel, batch_size]@.
convMnistTwoS
  :: forall kh kw h w c_out n_hidden batch_size target r.
     ( 1 <= kh
     , 1 <= kw
     , ADReady target, GoodScalar r, Differentiable r )
  => SNat kh -> SNat kw -> SNat h -> SNat w
  -> SNat c_out -> SNat n_hidden -> SNat batch_size
  -> PrimalOf target (TKS '[batch_size, 1, h, w] r)
  -> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
  -> target (TKS '[SizeMnistLabel, batch_size] r)
convMnistTwoS kh@SNat kw@SNat h@SNat w@SNat
              c_out@SNat _n_hidden@SNat batch_size@SNat
              input
              ( (ker1, bias1), (ker2, bias2)
              , (weightsDense, biasesDense), (weightsReadout, biasesReadout) ) =
  -- After two halvings the spatial sizes are @(x `Div` 2) `Div` 2@, but the
  -- dense weights are typed with @x `Div` 4@. These are equal for natural
  -- numbers. The equalities are asserted here, presumably because the type
  -- checker does not derive them on its own.
  assumeEquality @(Div (Div w 2) 2) @(Div w 4) $
  assumeEquality @(Div (Div h 2) 2) @(Div h 4) $
  let t1 :: target (TKS '[batch_size, c_out, h `Div` 2, w `Div` 2] r)
      t1 = convMnistLayerS kh kw h w (SNat @1) c_out batch_size
                           ker1 (sfromPrimal input) bias1
      t2 :: target (TKS '[ batch_size, c_out
                         , (h `Div` 2) `Div` 2, (w `Div` 2) `Div` 2 ] r)
      t2 = convMnistLayerS kh kw (SNat @(h `Div` 2)) (SNat @(w `Div` 2))
                           c_out c_out batch_size
                           ker2 t1 bias2
      -- Flatten each image's feature maps into a single vector.
      m1 :: target (TKS '[batch_size, c_out * (h `Div` 4) * (w `Div` 4)] r)
      m1 = sreshape t2
      -- Column-per-image layout: [n_hidden, batch_size].
      denseLayer :: target (TKS '[n_hidden, batch_size] r)
      denseLayer = smatmul2 weightsDense (str m1)
                   + str (sreplicate biasesDense)
  in smatmul2 weightsReadout (reluS denseLayer)
     + str (sreplicate biasesReadout)
-- | Training loss of 'convMnistTwoS' on one batch of MNIST glyphs and labels.
--
-- Steps:
--
--   * Each glyph is reshaped to add a singleton channel axis:
--     @[batch_size, h, w]@ becomes @[batch_size, 1, h, w]@.
--   * The labels are transposed to @[SizeMnistLabel, batch_size]@ so they
--     match the network output.
--   * The two are scored with 'lossSoftMaxCrossEntropyS'.
--   * The score is multiplied by @1 / batch_size@. Presumably this turns a
--     batch-summed loss into a per-example mean; confirm against
--     'lossSoftMaxCrossEntropyS'.
--
-- NOTE(review): a batch size of 0 makes the scale @recip 0@.
convMnistLossFusedS
  :: forall kh kw h w c_out n_hidden batch_size target r.
     ( h ~ SizeMnistHeight, w ~ SizeMnistWidth
     , 1 <= kh
     , 1 <= kw
     , ADReady target, ADReady (PrimalOf target)
     , GoodScalar r, Differentiable r )
  => SNat kh -> SNat kw
  -> SNat c_out
  -> SNat n_hidden -> SNat batch_size
  -> ( PrimalOf target (TKS '[batch_size, h, w] r)
     , PrimalOf target (TKS '[batch_size, SizeMnistLabel] r) )
  -> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
  -> target (TKScalar r)
convMnistLossFusedS kh@SNat kw@SNat
                    c_out@SNat n_hidden@SNat batch_size@SNat
                    (glyphS, labelS) adparameters =
  let input :: PrimalOf target (TKS '[batch_size, 1, h, w] r)
      input = sreshape glyphS
      result :: target (TKS '[SizeMnistLabel, batch_size] r)
      result = convMnistTwoS kh kw (SNat @h) (SNat @w)
                             c_out n_hidden batch_size
                             input adparameters
      targets :: PrimalOf target (TKS '[SizeMnistLabel, batch_size] r)
      targets = str labelS
      loss :: target (TKScalar r)
      loss = lossSoftMaxCrossEntropyS targets result
      scale :: PrimalOf target (TKScalar r)
      scale = recip $ kconcrete $ fromInteger $ fromSNat batch_size
  in kfromPrimal scale * loss
-- | Classification accuracy of the two-convolution network on one batch.
--
-- The glyphs of the batch are reshaped to @[batch_size, 1, h, w]@ (a
-- singleton channel axis is inserted) and fed through 'convMnistTwoS'
-- together with the given parameters. The network output, of shape
-- @[SizeMnistLabel, batch_size]@, is transposed and split so that each
-- sample yields one vector of class scores. A sample counts as correct when
-- 'V.maxIndex' of its score vector equals 'V.maxIndex' of its label vector.
--
-- The result is the number of correct samples divided by @batch_size@.
-- An empty batch (@batch_size == 0@) is caught by the first equation and
-- yields @0@, so no division by zero takes place.
convMnistTestS
  :: forall kh kw h w c_out n_hidden batch_size target r.
     ( h ~ SizeMnistHeight, w ~ SizeMnistWidth
     , 1 <= kh
     , 1 <= kw
     , target ~ Concrete
     , GoodScalar r, Differentiable r )
  => SNat kh
       -- ^ kernel height minus one (kernels have @kh + 1@ rows)
  -> SNat kw
       -- ^ kernel width minus one (kernels have @kw + 1@ columns)
  -> SNat c_out
       -- ^ number of convolution output channels
  -> SNat n_hidden
       -- ^ size of the hidden dense layer
  -> SNat batch_size
       -- ^ number of samples in the batch
  -> MnistDataBatchS batch_size r
       -- ^ @(glyphs, labels)@ of the batch
  -> ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
       -- ^ network parameters to evaluate
  -> r
       -- ^ fraction of correctly classified samples
convMnistTestS _ _ _ _ batch_size@SNat _ _
  | sNatValue batch_size == 0 = 0
convMnistTestS kh@SNat kw@SNat c_out@SNat n_hidden@SNat batch_size@SNat
               (glyphS, labelS) testParams =
  let -- Insert the singleton channel axis expected by the first convolution.
      input :: target (TKS '[batch_size, 1, h, w] r)
      input = sconcrete $ Nested.sreshape knownShS glyphS
      outputS :: Concrete (TKS '[SizeMnistLabel, batch_size] r)
      outputS =
        let nn :: ADCnnMnistParametersShaped target h w kh kw c_out n_hidden r
               -> target (TKS '[SizeMnistLabel, batch_size] r)
            -- 'convMnistTwoS' takes its input as @PrimalOf target@;
            -- 'input' is supplied directly here (target is 'Concrete').
            nn = convMnistTwoS kh kw (SNat @h) (SNat @w)
                               c_out n_hidden batch_size
                               input
        in nn testParams
      -- One score vector per sample: transpose to
      -- [batch_size, SizeMnistLabel], then split along the batch axis.
      outputs :: [Vector r]
      outputs = map stoVector
                $ sunravelToList
                $ stranspose @'[1, 0] outputS
      -- One label vector per sample, split along the batch axis.
      labels :: [Vector r]
      labels = map stoVector
               $ sunravelToList @_ @_ @(TKScalar r)
               $ sconcrete labelS
      -- 1 when the predicted class index equals the labelled one, else 0.
      matchesLabels :: Vector r -> Vector r -> Int
      matchesLabels output label
        | V.maxIndex output == V.maxIndex label = 1
        | otherwise = 0
  in fromIntegral (sum (zipWith matchesLabels outputs labels))
     / fromInteger (fromSNat batch_size)