{-# OPTIONS_GHC -Wno-missing-export-lists #-}
module MnistRnnRanked2 where
import Prelude hiding (foldl')
import Data.Kind (Type)
import Data.List (foldl')
import Data.Vector.Generic qualified as V
import Data.Vector.Storable (Vector)
import GHC.TypeLits (KnownNat, Nat, type (+))
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Ranked.Shape
import HordeAd
import MnistData
-- | The differentiable type of all trainable parameters of this nn,
-- shaped version: every dimension width is tracked in the types.
--
-- Components: the first recurrent layer (from @SizeMnistHeight@ inputs
-- to @width@ units), the second recurrent layer (@width@ to @width@),
-- and the output layer's weight matrix and bias (@SizeMnistLabel@ rows).
type ADRnnMnistParametersShaped (target :: Target) width r =
  ( LayerWeigthsRNNShaped target SizeMnistHeight width r
  , LayerWeigthsRNNShaped target width width r
  , ( target (TKS '[SizeMnistLabel, width] r)
    , target (TKS '[SizeMnistLabel] r) ) )
-- | Weights of one recurrent layer mapping @inW@ inputs to @outW@ units,
-- with both widths checked statically.
type LayerWeigthsRNNShaped :: Target -> Nat -> Nat -> Type -> Type
type LayerWeigthsRNNShaped tgt inW outW r =
  ( tgt (TKS '[outW, inW] r)   -- input weight matrix
  , tgt (TKS '[outW, outW] r)  -- state weight matrix
  , tgt (TKS '[outW] r)        -- bias
  )
-- | The differentiable type of all trainable parameters of this nn,
-- ranked version (widths are only known at runtime): two recurrent
-- layers followed by the output layer's weight matrix and bias.
type ADRnnMnistParameters tgt r =
  ( LayerWeigthsRNN tgt r           -- first recurrent layer
  , LayerWeigthsRNN tgt r           -- second recurrent layer
  , (tgt (TKR 2 r), tgt (TKR 1 r))  -- output layer: weights, bias
  )
-- | Weights of one recurrent layer, ranked version: input weight matrix,
-- state weight matrix and bias, in that order.
type LayerWeigthsRNN (tgt :: Target) r =
  (tgt (TKR 2 r), tgt (TKR 2 r), tgt (TKR 1 r))
-- | Run a state-consuming computation on an initial state of the given
-- shape whose every element is @0@ (built with 'rrepl').
zeroStateR
  :: (BaseTensor target, GoodScalar r)
  => IShR n                   -- ^ shape of the initial state
  -> (target (TKR n r) -> a)  -- ^ computation receiving the zero state
  -> a
zeroStateR sh f = f (rrepl sh 0)
-- | Lift a single-step recurrent function to one that consumes a whole
-- sequence.
--
-- The input tensor is unravelled along its outermost dimension and @f@
-- is applied to each slice in order. The state is threaded through the
-- steps, and the same @w@ (e.g., layer weights) is passed to every step.
-- Only the last step's output is returned, together with the final
-- state.
--
-- The outermost dimension of the input must be non-empty. With no steps
-- there is no output, and demanding it raises an error.
unrollLastR :: forall target state c w r n.
               (BaseTensor target, GoodScalar r, KnownNat n)
            => (state -> target (TKR n r) -> w -> (c, state))
            -> (state -> target (TKR (1 + n) r) -> w -> (c, state))
unrollLastR f s0 xs w =
  let g :: (c, state) -> target (TKR n r) -> (c, state)
      -- The previous step's output is discarded. The state is forced to
      -- WHNF at each step so the strict fold does not accumulate thunks
      -- in it.
      g (_, !s) x = f s x w
      -- Only observable when the input has zero steps.
      noOutput :: c
      noOutput = error "unrollLastR: the input has no steps to unroll"
  in foldl' g (noOutput, s0) (runravelToList xs)
-- | A single recurrent layer with @tanh@ activation:
-- @tanh (wX * x + wS * s + b)@, where the bias is broadcast over the
-- batch.
rnnMnistLayerR
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 2 r)          -- ^ in state, @[out_width, batch_size]@
  -> target (TKR 2 r)          -- ^ input, @[in_width, batch_size]@
  -> LayerWeigthsRNN target r  -- ^ parameters @(wX, wS, b)@
  -> target (TKR 2 r)          -- ^ out state, @[out_width, batch_size]@
rnnMnistLayerR s x (wX, wS, b) = case rshape s of
  _out_width :$: batch_size :$: ZSR ->
    let -- One copy of the bias per example, transposed so that it
        -- matches the @[out_width, batch_size]@ layout of the products.
        bias = rtr (rreplicate batch_size b)
        y = wX `rmatmul2` x + wS `rmatmul2` s + bias
    in tanh y
-- | Composition of two recurrent layers, performing one time step.
--
-- The combined state stacks the first layer's state on top of the second
-- layer's state along the outer dimension. The step input @x@ feeds the
-- first layer, and the first layer's new state is the second layer's
-- input.
rnnMnistTwoR
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 2 r)
       -- ^ combined state, @[2 * out_width, batch_size]@
  -> PrimalOf target (TKR 2 r)
       -- ^ input of this step, @[in_width, batch_size]@
  -> ( LayerWeigthsRNN target r    -- first layer, in_width to out_width
     , LayerWeigthsRNN target r )  -- second layer, out_width to out_width
  -> ( target (TKR 2 r)    -- second layer's new state, @[out_width, batch_size]@
     , target (TKR 2 r) )  -- new combined state, @[2 * out_width, batch_size]@
rnnMnistTwoR s' x ((wX, wS, b), (wX2, wS2, b2)) = case rshape s' of
  out_width_x_2 :$: _batch_size :$: ZSR ->
    let out_width = out_width_x_2 `div` 2
        -- @s'@ is let-bound with 'tlet' before being sliced twice
        -- (presumably so it is shared rather than duplicated).
        s3 = tlet s' $ \s ->
          let s1 = rslice 0 out_width s          -- first layer's state
              s2 = rslice out_width out_width s  -- second layer's state
              vec1 = rnnMnistLayerR s1 (rfromPrimal x) (wX, wS, b)
              vec2 = rnnMnistLayerR s2 vec1 (wX2, wS2, b2)
          in rappend vec1 vec2
    in (rslice out_width out_width s3, s3)
-- | Run the two-layer recurrent nn over a whole input sequence, starting
-- from an all-zeros state. The output of the last step then goes through
-- the affine output layer @w3 * out + b3@.
rnnMnistZeroR
  :: (ADReady target, GoodScalar r, Differentiable r)
  => Int
       -- ^ batch size
  -> PrimalOf target (TKR 3 r)
       -- ^ input sequence, unrolled along its outermost dimension;
       --   each step is @[in_width, batch_size]@
  -> ADRnnMnistParameters target r
       -- ^ parameters
  -> target (TKR 2 r)
       -- ^ @[rows of w3, batch_size]@
rnnMnistZeroR batch_size xs
              ((wX, wS, b), (wX2, wS2, b2), (w3, b3)) = case rshape b of
  out_width :$: ZSR ->
    let -- Shape of the stacked state of both recurrent layers.
        sh :: IShR 2
        sh = 2 * out_width :$: batch_size :$: ZSR
        (out, _finalState) =
          zeroStateR sh (unrollLastR rnnMnistTwoR) xs
                     ((wX, wS, b), (wX2, wS2, b2))
    in w3 `rmatmul2` out + rtr (rreplicate batch_size b3)
-- | The neural network composed with the SoftMax-CrossEntropy loss
-- function. The loss is multiplied by the reciprocal of the batch size.
rnnMnistLossFusedR
  :: (ADReady target, ADReady (PrimalOf target), GoodScalar r, Differentiable r)
  => Int
       -- ^ batch size
  -> (PrimalOf target (TKR 3 r), PrimalOf target (TKR 2 r))
       -- ^ glyphs and labels of the batch
  -> ADRnnMnistParameters target r
       -- ^ parameters
  -> target (TKScalar r)
rnnMnistLossFusedR batch_size (glyphR, labelR) adparameters =
  let -- Reverse the axis order. The last axis of @glyphR@ becomes the
      -- one unrolled by 'rnnMnistZeroR', and the first axis (presumably
      -- the batch axis; TODO confirm) becomes the last.
      xs = rtranspose [2, 1, 0] glyphR
      result = rnnMnistZeroR batch_size xs adparameters
      -- Transposed to match the @[labels, batch_size]@ layout of
      -- @result@ (assumes @labelR@ is @[batch_size, labels]@).
      targets = rtr labelR
      loss = lossSoftMaxCrossEntropyR targets result
  in kfromPrimal (recip $ kconcrete $ fromIntegral batch_size) * loss
-- | The fraction of a test batch that the trained network classifies
-- correctly. An example matches when its output vector and its label
-- vector reach their maximum at the same index. A batch size of @0@
-- yields @0@.
rnnMnistTestR
  :: forall target r.
     (target ~ Concrete, GoodScalar r, Differentiable r)
  => Int
       -- ^ batch size
  -> MnistDataBatchR r
       -- ^ glyphs and labels of the test batch
  -> ADRnnMnistParameters target r
       -- ^ trained parameters
  -> r
rnnMnistTestR 0 _ _ = 0
rnnMnistTestR batch_size (glyphR, labelR) testParams =
  let -- Same axis reversal as in the training loss.
      input :: target (TKR 3 r)
      input = rconcrete $ Nested.rtranspose [2, 1, 0] glyphR
      outputR :: Concrete (TKR 2 r)
      outputR = rnnMnistZeroR batch_size input testParams
      -- Transpose to put the batch axis first, then split into one
      -- output vector per example.
      outputs :: [Vector r]
      outputs = map rtoVector $ runravelToList $ rtranspose [1, 0] outputR
      labels :: [Vector r]
      labels = map rtoVector
               $ runravelToList @_ @(TKScalar r)
               $ rconcrete labelR
      matchesLabels :: Vector r -> Vector r -> Int
      matchesLabels output label
        | V.maxIndex output == V.maxIndex label = 1
        | otherwise = 0
  in fromIntegral (sum (zipWith matchesLabels outputs labels))
     / fromIntegral batch_size