{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -Wno-missing-export-lists #-}
module MnistRnnShaped2 where
import Prelude hiding (foldl')
import Data.Kind (Type)
import Data.List (foldl')
import Data.Vector.Generic qualified as V
import Data.Vector.Storable (Vector)
import GHC.TypeLits (KnownNat, Nat, fromSNat, type (*))
import Data.Array.Nested.Permutation qualified as Permutation
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Shaped.Shape
import HordeAd
import MnistData
-- | The trainable parameters of the network, as a nested tuple:
--
-- * the weights of the first recurrent layer, with input width
--   @sizeMnistHeight@ and output width @width@;
-- * the weights of the second recurrent layer, mapping @width@ to @width@;
-- * the weight matrix and the bias of the final layer, which has
--   @SizeMnistLabel@ outputs.
--
-- All widths are tracked at the type level.
type ADRnnMnistParametersShaped
       (target :: Target) sizeMnistHeight width r =
  ( LayerWeigthsRNNShaped target sizeMnistHeight width r
  , LayerWeigthsRNNShaped target width width r
  , ( target (TKS '[SizeMnistLabel, width] r)
    , target (TKS '[SizeMnistLabel] r) ) )
-- | The weights of a single recurrent layer with the given input and
-- output widths. The components are consumed by 'rnnMnistLayerS' as,
-- in order, the matrix applied to the layer input, the matrix applied
-- to the previous state, and the bias vector.
type LayerWeigthsRNNShaped :: Target -> Nat -> Nat -> Type -> Type
type LayerWeigthsRNNShaped target in_width out_width r =
  ( target (TKS '[out_width, in_width] r)   -- input weight
  , target (TKS '[out_width, out_width] r)  -- state weight
  , target (TKS '[out_width] r) )           -- bias
-- | Apply a continuation to a tensor of shape @sh@ filled with zeros
-- (built with @srepl 0@). Used to supply the initial state of
-- a recurrent computation.
zeroStateS
  :: (BaseTensor target, KnownShS sh, GoodScalar r)
  => (target (TKS sh r)  -- state
      -> a)
  -> a
zeroStateS f = f (srepl 0)
-- | Lift a single-step function to a function over a whole sequence.
--
-- The input tensor is split along its outermost dimension with
-- @sunravelToList@ and the step function is folded over the slices from
-- left to right, threading the state and passing the same @w@ to each step.
-- The state is forced at every step. The result is the output of the
-- last step paired with the final state.
--
-- If the outermost dimension is empty, no step is run: the state is returned
-- unchanged and the output component is a bottom that raises a descriptive
-- error when forced.
unrollLastS :: forall target state c w r n sh.
               (BaseTensor target, KnownNat n, KnownShS sh, GoodScalar r)
            => (state -> target (TKS sh r) -> w -> (c, state))
            -> (state -> target (TKS (n ': sh) r) -> w -> (c, state))
unrollLastS f s0 xs w =
  let g :: (c, state) -> target (TKS sh r) -> (c, state)
      -- The previous output is discarded; only the last one is returned.
      g (_, !s) x = f s x w
      -- Placeholder output, overwritten by the first step of the fold.
      noOutput :: c
      noOutput =
        error "unrollLastS: no output, the unrolled dimension is empty"
  in foldl' g (noOutput, s0) (sunravelToList xs)
-- | A single recurrent layer applied to a whole batch.
--
-- Computes @tanh (wX * x + wS * s + bias)@, where the products are matrix
-- multiplications (@smatmul2@) and the bias vector is replicated across
-- the batch and transposed to match the @[out_width, batch_size]@ shape.
-- The result is the new state of the layer.
rnnMnistLayerS
  :: (ADReady target, GoodScalar r, Differentiable r)
  => SNat in_width -> SNat out_width -> SNat batch_size
       -- ^ these singletons tie the type-level widths to value arguments
  -> target (TKS '[out_width, batch_size] r)  -- ^ previous state
  -> target (TKS '[in_width, batch_size] r)  -- ^ input
  -> LayerWeigthsRNNShaped target in_width out_width r
  -> target (TKS '[out_width, batch_size] r)  -- ^ output state
rnnMnistLayerS SNat SNat SNat
               s x (wX, wS, b) =
  let y = wX `smatmul2` x + wS `smatmul2` s
          + str (sreplicate b)
  in tanh y
-- | One time step of the two-layer recurrent network.
--
-- The incoming state stacks the states of both layers along the first
-- dimension (hence its size @2 * out_width@). It is shared with @tlet@,
-- split into the two halves with @sslice@, and each half is fed to its
-- layer: the first layer consumes the input @x@ (lifted with
-- @sfromPrimal@), the second layer consumes the output of the first.
--
-- Returns the output of the second layer together with the new stacked
-- state (first layer output appended with second layer output).
rnnMnistTwoS
  :: (ADReady target, GoodScalar r, Differentiable r)
  => SNat out_width -> SNat batch_size -> SNat sizeMnistH
  -> target (TKS '[2 * out_width, batch_size] r)  -- ^ initial state
  -> PrimalOf target (TKS '[sizeMnistH, batch_size] r)
  -> ( LayerWeigthsRNNShaped target sizeMnistH out_width r
     , LayerWeigthsRNNShaped target out_width out_width r )
  -> ( target (TKS '[out_width, batch_size] r)
     , target (TKS '[2 * out_width, batch_size] r) )  -- ^ final state
rnnMnistTwoS out_width@SNat
             batch_size@SNat
             sizeMnistHeightHere@SNat
             s' x ((wX, wS, b), (wX2, wS2, b2)) =
  let s3 = tlet s' $ \s ->
        let -- First half of the stacked state: state of layer one.
            s1 = sslice (SNat @0) out_width SNat s
            -- Second half of the stacked state: state of layer two.
            s2 = sslice out_width out_width SNat s
            vec1 = rnnMnistLayerS sizeMnistHeightHere
                                  out_width
                                  batch_size
                                  s1 (sfromPrimal x) (wX, wS, b)
            vec2 = rnnMnistLayerS out_width
                                  out_width
                                  batch_size
                                  s2 vec1 (wX2, wS2, b2)
        in sappend vec1 vec2
  -- The second half of the new state is the output of the second layer.
  in (sslice out_width out_width SNat s3, s3)
-- | The whole network applied to a batch of inputs, starting from
-- an all-zero state.
--
-- The input tensor is unrolled along its outermost dimension
-- (@sizeMnistW@) with 'unrollLastS', running 'rnnMnistTwoS' once per slice
-- from the zero state provided by 'zeroStateS'. The output of the last
-- step is then passed through the final layer: multiplied by @w3@ and
-- offset by the bias @b3@ replicated across the batch.
rnnMnistZeroS
  :: (ADReady target, GoodScalar r, Differentiable r)
  => SNat out_width
  -> SNat batch_size
  -> SNat sizeMnistH -> SNat sizeMnistW
  -> PrimalOf target (TKS '[sizeMnistW, sizeMnistH, batch_size] r)
  -> ADRnnMnistParametersShaped target sizeMnistH out_width r
  -> target (TKS '[SizeMnistLabel, batch_size] r)
rnnMnistZeroS out_width@SNat
              batch_size@SNat
              sizeMnistHeightHere@SNat _sizeMnistWidthHere@SNat
              xs ((wX, wS, b), (wX2, wS2, b2), (w3, b3)) =
  let rnnMnistTwo = rnnMnistTwoS out_width batch_size sizeMnistHeightHere
      -- Only the output of the last step is used; the final state is dropped.
      (out, _s) = zeroStateS (unrollLastS rnnMnistTwo) xs
                             ((wX, wS, b), (wX2, wS2, b2))
  in w3 `smatmul2` out + str (sreplicate b3)
-- | The loss of the network on a batch of glyphs and labels.
--
-- The glyph tensor is transposed with the permutation @[2, 1, 0]@, so that
-- the batch dimension comes last, and fed to 'rnnMnistZeroS'. The labels
-- are transposed to match the shape of the network output. The result of
-- @lossSoftMaxCrossEntropyS@ on the two is multiplied by the reciprocal
-- of the batch size.
rnnMnistLossFusedS
  :: forall target h w out_width batch_size r.
     ( h ~ SizeMnistHeight, w ~ SizeMnistWidth, Differentiable r
     , ADReady target, ADReady (PrimalOf target), GoodScalar r)
  => SNat out_width
  -> SNat batch_size
  -> ( PrimalOf target (TKS '[batch_size, h, w] r)
     , PrimalOf target (TKS '[batch_size, SizeMnistLabel] r) )
  -> ADRnnMnistParametersShaped target h out_width r
  -> target (TKScalar r)
rnnMnistLossFusedS out_width@SNat
                   batch_size@SNat
                   (glyphS, labelS) adparameters =
  let xs = stranspose @'[2, 1, 0] glyphS
      result = rnnMnistZeroS out_width
                             batch_size
                             (SNat @h) (SNat @w)
                             xs adparameters
      targets = str labelS
      loss = lossSoftMaxCrossEntropyS targets result
  in kfromPrimal (recip $ kconcrete $ fromInteger $ fromSNat batch_size) * loss
-- | Measure the accuracy of the RNN on one batch of MNIST data.
--
-- Runs 'rnnMnistZeroS' with the given parameters on the glyphs of the batch
-- and returns, as a value of type @r@, the fraction of samples for which the
-- index of the largest network output equals the index of the largest entry
-- of the corresponding label.
--
-- Arguments, in order: the @out_width@ and @batch_size@ singletons, the
-- batch as a pair @(glyphS, labelS)@, and the parameters to evaluate.
-- The number of matches is divided by @batch_size@; there is no guard for
-- a batch of size zero.
--
-- NOTE(review): in this listing each token appears to be followed by its
-- inferred type, and the signature is repeated below in elaborated form.
-- There @w@ is shown equal to @SizeMnistHeight@ rather than
-- @SizeMnistWidth@; presumably the two are synonyms (confirm in MnistData).
rnnMnistTestS
:: forall target h w out_width batch_size r.
( h ~ SizeMnistHeight, w ~ SizeMnistWidth
, target ~ Concrete, Differentiable r, GoodScalar r )
=> SNat out_width
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADRnnMnistParametersShaped target h out_width r
-> r
rnnMnistTestS :: forall (target :: Target) (h :: Nat) (w :: Nat) (out_width :: Nat)
(batch_size :: Nat) r.
((h :: Nat) ~ (SizeMnistHeight :: Nat),
(w :: Nat) ~ (SizeMnistHeight :: Nat),
(target :: Target) ~ (Concrete :: Target), Differentiable r,
GoodScalar r) =>
SNat out_width
-> SNat batch_size
-> MnistDataBatchS batch_size r
-> ADRnnMnistParametersShaped target h out_width r
-> r
rnnMnistTestS out_width :: SNat out_width
out_width@SNat out_width
SNat batch_size :: SNat batch_size
batch_size@SNat batch_size
SNat
(Shaped
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r
glyphS, Shaped
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
labelS) ADRnnMnistParametersShaped target h out_width r
testParams =
let
-- Reorder the glyph axes with the permutation [2, 1, 0], taking the shape
-- [batch_size, height, height] to [height, height, batch_size], and wrap
-- the resulting array as a concrete tensor.
input :: Concrete
(TKS
((':)
@Nat
SizeMnistHeight
((':) @Nat SizeMnistHeight ((':) @Nat batch_size ('[] @Nat))))
r)
input = Shaped
((':)
@Nat
SizeMnistHeight
((':) @Nat SizeMnistHeight ((':) @Nat batch_size ('[] @Nat))))
r
-> Concrete
(TKS
((':)
@Nat
SizeMnistHeight
((':) @Nat SizeMnistHeight ((':) @Nat batch_size ('[] @Nat))))
r)
forall r (target :: Target) (sh :: [Nat]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete
(Shaped
((':)
@Nat
SizeMnistHeight
((':) @Nat SizeMnistHeight ((':) @Nat batch_size ('[] @Nat))))
r
-> Concrete
(TKS
((':)
@Nat
SizeMnistHeight
((':) @Nat SizeMnistHeight ((':) @Nat batch_size ('[] @Nat))))
r))
-> Shaped
((':)
@Nat
SizeMnistHeight
((':) @Nat SizeMnistHeight ((':) @Nat batch_size ('[] @Nat))))
r
-> Concrete
(TKS
((':)
@Nat
SizeMnistHeight
((':) @Nat SizeMnistHeight ((':) @Nat batch_size ('[] @Nat))))
r)
forall a b. (a -> b) -> a -> b
$ Perm ((':) @Nat 2 ((':) @Nat 1 ((':) @Nat 0 ('[] @Nat))))
-> Shaped
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r
-> Shaped
(PermutePrefix
@Nat
((':) @Nat 2 ((':) @Nat 1 ((':) @Nat 0 ('[] @Nat))))
((':)
@Nat
batch_size
((':)
@Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat)))))
r
forall (is :: [Nat]) (sh :: [Nat]) a.
(IsPermutation is, (<=) @Nat (Rank @Nat is) (Rank @Nat sh),
Elt a) =>
Perm is -> Shaped sh a -> Shaped (PermutePrefix @Nat is sh) a
Nested.stranspose (forall (l :: [Nat]). KnownPerm l => Perm l
Permutation.makePerm @'[2, 1, 0]) Shaped
((':)
@Nat
batch_size
((':) @Nat SizeMnistHeight ((':) @Nat SizeMnistHeight ('[] @Nat))))
r
glyphS
-- Network output for the whole batch, shaped [SizeMnistLabel, batch_size].
outputS :: Concrete (TKS '[SizeMnistLabel, batch_size] r)
outputS :: Concrete
(TKS
((':) @Nat SizeMnistLabel ((':) @Nat batch_size ('[] @Nat))) r)
outputS =
-- The network with every argument except the parameters already applied.
let nn :: ADRnnMnistParametersShaped target h out_width r
-> target (TKS '[SizeMnistLabel, batch_size] r)
nn :: ADRnnMnistParametersShaped target h out_width r
-> target
(TKS
((':) @Nat SizeMnistLabel ((':) @Nat batch_size ('[] @Nat))) r)
nn = SNat out_width
-> SNat batch_size
-> SNat h
-> SNat w
-> PrimalOf
target
(TKS
((':) @Nat w ((':) @Nat h ((':) @Nat batch_size ('[] @Nat)))) r)
-> ADRnnMnistParametersShaped target h out_width r
-> target
(TKS
((':) @Nat SizeMnistLabel ((':) @Nat batch_size ('[] @Nat))) r)
forall (target :: Target) r (out_width :: Nat) (batch_size :: Nat)
(sizeMnistH :: Nat) (sizeMnistW :: Nat).
(ADReady target, GoodScalar r, Differentiable r) =>
SNat out_width
-> SNat batch_size
-> SNat sizeMnistH
-> SNat sizeMnistW
-> PrimalOf
target
(TKS
((':)
@Nat
sizeMnistW
((':) @Nat sizeMnistH ((':) @Nat batch_size ('[] @Nat))))
r)
-> ADRnnMnistParametersShaped target sizeMnistH out_width r
-> target
(TKS
((':) @Nat SizeMnistLabel ((':) @Nat batch_size ('[] @Nat))) r)
rnnMnistZeroS SNat out_width
out_width
SNat batch_size
batch_size
(forall (n :: Nat). KnownNat n => SNat n
SNat @h) (forall (n :: Nat). KnownNat n => SNat n
SNat @w)
PrimalOf
target
(TKS
((':) @Nat w ((':) @Nat h ((':) @Nat batch_size ('[] @Nat)))) r)
Concrete
(TKS
((':)
@Nat
SizeMnistHeight
((':) @Nat SizeMnistHeight ((':) @Nat batch_size ('[] @Nat))))
r)
input
in ADRnnMnistParametersShaped target h out_width r
-> target
(TKS
((':) @Nat SizeMnistLabel ((':) @Nat batch_size ('[] @Nat))) r)
nn ADRnnMnistParametersShaped target h out_width r
testParams
-- One output vector per sample: transpose outputS to
-- [batch_size, SizeMnistLabel], split it along the leading axis and turn
-- each [SizeMnistLabel] slice into a storable Vector.
outputs :: [Vector r]
outputs = (Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)
-> Vector r)
-> [Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)]
-> [Vector r]
forall a b. (a -> b) -> [a] -> [b]
map Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r) -> Vector r
forall r (sh :: [Nat]).
GoodScalar r =>
Concrete (TKS sh r) -> Vector r
stoVector ([Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)]
-> [Vector r])
-> [Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)]
-> [Vector r]
forall a b. (a -> b) -> a -> b
$ Concrete
(TKS2
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
(TKScalar r))
-> [Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)]
forall (n :: Nat) (sh :: [Nat]) (x :: TK) (target :: Target).
(KnownNat n, KnownShS sh, KnownSTK x, BaseTensor target) =>
target (TKS2 ((':) @Nat n sh) x) -> [target (TKS2 sh x)]
sunravelToList
(Concrete
(TKS2
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
(TKScalar r))
-> [Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)])
-> Concrete
(TKS2
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
(TKScalar r))
-> [Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)]
forall a b. (a -> b) -> a -> b
$ forall (perm :: [Nat]) (sh :: [Nat]) (x :: TK) (target :: Target).
(KnownPerm perm, IsPermutation perm,
(<=) @Nat (Rank @Nat perm) (Rank @Nat sh), KnownSTK x,
BaseTensor target) =>
target (TKS2 sh x) -> target (TKS2 (PermutePrefix @Nat perm sh) x)
stranspose @'[1, 0] Concrete
(TKS
((':) @Nat SizeMnistLabel ((':) @Nat batch_size ('[] @Nat))) r)
outputS
-- One label vector per sample: labelS already has shape
-- [batch_size, SizeMnistLabel], so it is only wrapped as a concrete tensor,
-- split along the leading axis and converted slice by slice.
labels :: [Vector r]
labels = (Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)
-> Vector r)
-> [Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)]
-> [Vector r]
forall a b. (a -> b) -> [a] -> [b]
map Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r) -> Vector r
forall r (sh :: [Nat]).
GoodScalar r =>
Concrete (TKS sh r) -> Vector r
stoVector
([Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)]
-> [Vector r])
-> [Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)]
-> [Vector r]
forall a b. (a -> b) -> a -> b
$ forall (n :: Nat) (sh :: [Nat]) (x :: TK) (target :: Target).
(KnownNat n, KnownShS sh, KnownSTK x, BaseTensor target) =>
target (TKS2 ((':) @Nat n sh) x) -> [target (TKS2 sh x)]
sunravelToList @_ @_ @(TKScalar r)
(Concrete
(TKS2
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
(TKScalar r))
-> [Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)])
-> Concrete
(TKS2
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
(TKScalar r))
-> [Concrete (TKS ((':) @Nat SizeMnistLabel ('[] @Nat)) r)]
forall a b. (a -> b) -> a -> b
$ Shaped
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
-> Concrete
(TKS2
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat)))
(TKScalar r))
forall r (target :: Target) (sh :: [Nat]).
(GoodScalar r, BaseTensor target) =>
Shaped sh r -> target (TKS sh r)
sconcrete Shaped
((':) @Nat batch_size ((':) @Nat SizeMnistLabel ('[] @Nat))) r
labelS
-- 1 when output and label have their maximum at the same index, else 0.
matchesLabels :: Vector r -> Vector r -> Int
matchesLabels :: Vector r -> Vector r -> Int
matchesLabels Vector r
output Vector r
label | Vector r -> Int
forall (v :: Type -> Type) a. (Vector v a, Ord a) => v a -> Int
V.maxIndex Vector r
output Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Vector r -> Int
forall (v :: Type -> Type) a. (Vector v a, Ord a) => v a -> Int
V.maxIndex Vector r
label = Int
1
| Bool
otherwise = Int
0
-- Number of matching samples divided by the batch size.
in Int -> r
forall a b. (Integral a, Num b) => a -> b
fromIntegral ([Int] -> Int
forall a. Num a => [a] -> a
forall (t :: Type -> Type) a. (Foldable t, Num a) => t a -> a
sum ((Vector r -> Vector r -> Int) -> [Vector r] -> [Vector r] -> [Int]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Vector r -> Vector r -> Int
matchesLabels [Vector r]
outputs [Vector r]
labels))
r -> r -> r
forall a. Fractional a => a -> a -> a
/ Integer -> r
forall a. Num a => Integer -> a
fromInteger (SNat batch_size -> Integer
forall (n :: Nat). SNat n -> Integer
fromSNat SNat batch_size
batch_size)