{-# LANGUAGE OverloadedLists #-}
{-# OPTIONS_GHC -Wno-missing-export-lists #-}
module MnistCnnRanked2 where
import Prelude
import Data.Vector.Generic qualified as V
import Data.Vector.Storable (Vector)
import GHC.TypeLits (type (*), type (+), type Div)
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Ranked.Shape
import HordeAd
import MnistData
-- | The differentiable type of all trainable parameters of this network,
-- shaped version: every dimension width is checked statically.
--
-- Because subtraction complicates inference of type-level naturals, @kh@
-- and @kw@ denote the kernel height and width /minus one/ (each kernel is
-- @kh + 1@ by @kw + 1@). The hidden dense layer consumes
-- @c_out * (h `Div` 4) * (w `Div` 4)@ features, matching the
-- @`div` 4@ flattening performed in 'convMnistTwoR'.
--
-- NOTE(review): the synonym's name was missing from the text this was
-- recovered from; @ADCnnMnistParametersShaped@ is reconstructed from the
-- naming of 'ADCnnMnistParameters' — confirm against callers.
type ADCnnMnistParametersShaped
       (target :: Target) h w kh kw c_out n_hidden r =
  ( ( target (TKS '[c_out, 1, kh + 1, kw + 1] r)
    , target (TKS '[c_out] r) )
  , ( target (TKS '[c_out, c_out, kh + 1, kw + 1] r)
    , target (TKS '[c_out] r) )
  , ( target (TKS '[n_hidden, c_out * (h `Div` 4) * (w `Div` 4)] r)
    , target (TKS '[n_hidden] r) )
  , ( target (TKS '[SizeMnistLabel, n_hidden] r)
    , target (TKS '[SizeMnistLabel] r) )
  )
-- | The differentiable type of all trainable parameters of this network,
-- ranked version: only the rank of each tensor is tracked, not its widths.
--
-- Four (weights, bias) pairs, in the order 'convMnistTwoR' consumes them:
-- first convolution, second convolution, hidden dense layer and readout
-- layer. Convolution kernels are rank 4, dense matrices rank 2 and every
-- bias rank 1.
type ADCnnMnistParameters (target :: Target) r =
  ( (target (TKR 4 r), target (TKR 1 r))  -- first convolution
  , (target (TKR 4 r), target (TKR 1 r))  -- second convolution
  , (target (TKR 2 r), target (TKR 1 r))  -- hidden dense layer
  , (target (TKR 2 r), target (TKR 1 r))  -- readout layer
  )
-- | One convolutional block: convolution with @ker@, a per-channel bias,
-- a ReLU, and finally 'maxPool2dUnpadded' with window 2 and stride 2.
--
-- The batch size and the spatial sizes used to broadcast the bias are read
-- from the shape of @input@, matched as @batch_size :$: _ :$: h :$: w@.
--
-- NOTE(review): the shapes below assume 'conv2dUnpadded' keeps the spatial
-- size @h x w@ and that pooling with stride 2 halves it — confirm against
-- HordeAd.
convMnistLayerR
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r)  -- ^ kernel, presumably [c_out, c_in, kh + 1, kw + 1]
  -> target (TKR 4 r)  -- ^ input, [batch_size, c_in, h, w]
  -> target (TKR 1 r)  -- ^ bias, [c_out]
  -> target (TKR 4 r)
       -- ^ pooled output, presumably [batch_size, c_out, h/2, w/2]
convMnistLayerR ker input bias =
  let (batch_size :$: _ :$: h :$: w :$: ZSR) = rshape input
      yConv = conv2dUnpadded ker input
      -- Replicate the bias to [batch_size, h, w, c_out], then permute the
      -- channel axis next to the batch axis so the result lines up with
      -- the convolution output.
      biasStretched = rtranspose [0, 3, 1, 2]
                      $ rreplicate batch_size $ rreplicate h $ rreplicate w bias
      yRelu = relu $ yConv + biasStretched
  in maxPool2dUnpadded 2 2 yRelu
-- | The whole network: two convolutional blocks ('convMnistLayerR'),
-- a flattening reshape, a ReLU dense hidden layer and a linear readout.
--
-- The result holds one column of pre-softmax label scores per image,
-- i.e. the batch runs along the second axis.
convMnistTwoR
  :: (ADReady target, GoodScalar r, Differentiable r)
  => Int  -- ^ image height
  -> Int  -- ^ image width
  -> Int  -- ^ batch size
  -> PrimalOf target (TKR 4 r)
       -- ^ input images, presumably [batch_size, 1, height, width]
  -> ADCnnMnistParameters target r  -- ^ parameters
  -> target (TKR 2 r)  -- ^ label scores, [n_labels, batch_size]
convMnistTwoR sizeMnistHeightI sizeMnistWidthI batch_size input
              ( (ker1, bias1), (ker2, bias2)
              , (weightsDense, biasesDense), (weightsReadout, biasesReadout) ) =
  let t1 = convMnistLayerR ker1 (rfromPrimal input) bias1
      -- presumably [batch_size, c_out, height `div` 4, width `div` 4]
      t2 = convMnistLayerR ker2 t1 bias2
      -- t2's channel count is the output width of the *second* convolution,
      -- which is the width of its bias (identical to bias1's width whenever
      -- both layers share a channel count).
      c_out = rwidth bias2
      -- Flatten each image's feature maps into one row; the two 2x2
      -- poolings shrink each spatial dimension by a factor of 4.
      m1 = rreshape (batch_size
                     :$: c_out * (sizeMnistHeightI `div` 4)
                               * (sizeMnistWidthI `div` 4)
                     :$: ZSR)
                    t2
      -- Images become columns: [features, batch_size].
      m2 = rtr m1
      denseLayer = weightsDense `rmatmul2` m2
                   + rtr (rreplicate batch_size biasesDense)
      denseRelu = relu denseLayer
  in weightsReadout `rmatmul2` denseRelu
     + rtr (rreplicate batch_size biasesReadout)
-- | Softmax cross-entropy loss of the network on one batch, scaled by
-- @1 / batch_size@.
--
-- The glyphs are reshaped to @[batch_size, 1, height, width]@ (a single
-- input channel) and run through 'convMnistTwoR'. The result is compared
-- with the transposed labels via 'lossSoftMaxCrossEntropyR'.
convMnistLossFusedR
  :: (ADReady target, ADReady (PrimalOf target), GoodScalar r, Differentiable r)
  => Int  -- ^ batch size
  -> ( PrimalOf target (TKR 3 r)  -- presumably [batch_size, height, width]
     , PrimalOf target (TKR 2 r) )  -- presumably [batch_size, n_labels]
  -> ADCnnMnistParameters target r  -- ^ parameters
  -> target (TKScalar r)
convMnistLossFusedR batch_size (glyphR, labelR) adparameters =
  let input = rreshape (batch_size
                        :$: 1
                        :$: sizeMnistHeightInt
                        :$: sizeMnistWidthInt
                        :$: ZSR)
                       glyphR
      result = convMnistTwoR sizeMnistHeightInt sizeMnistWidthInt
                             batch_size input adparameters
      -- convMnistTwoR puts the batch on the second axis; match that layout.
      targets = rtr labelR
      loss = lossSoftMaxCrossEntropyR targets result
  in kfromPrimal (recip $ kconcrete $ fromIntegral batch_size) * loss
-- | Fraction of the images in a batch that the network classifies
-- correctly, evaluated on concrete tensors.
--
-- An image counts as correct when the index of its largest output score
-- equals the index of the largest entry of its label vector. A batch size
-- of 0 yields 0 rather than dividing by zero.
convMnistTestR
  :: forall target r.
     (target ~ Concrete, GoodScalar r, Differentiable r)
  => Int
  -> MnistDataBatchR r
  -> ADCnnMnistParameters Concrete r
  -> r
convMnistTestR 0 _ _ = 0
convMnistTestR batch_size (glyphR, labelR) testParams =
  let input :: target (TKR 4 r)
      input =
        rconcrete $ Nested.rreshape [ batch_size, 1
                                    , sizeMnistHeightInt, sizeMnistWidthInt ]
                                    glyphR
      outputR :: Concrete (TKR 2 r)
      outputR =
        let nn :: ADCnnMnistParameters target r
               -> target (TKR 2 r)  -- [n_labels, batch_size]
            nn = convMnistTwoR sizeMnistHeightInt sizeMnistWidthInt
                               batch_size input
        in nn testParams
      -- One score vector per image: swap to [batch_size, n_labels] first.
      outputs = map rtoVector $ runravelToList
                $ rtranspose [1, 0] outputR
      labels = map rtoVector
               $ runravelToList @_ @(TKScalar r)
               $ rconcrete labelR
      matchesLabels :: Vector r -> Vector r -> Int
      matchesLabels output label | V.maxIndex output == V.maxIndex label = 1
                                 | otherwise = 0
  in fromIntegral (sum (zipWith matchesLabels outputs labels))
     / fromIntegral batch_size