{-# LANGUAGE OverloadedLists #-}
{-# OPTIONS_GHC -Wno-missing-export-lists #-}
-- | Ranked tensor-based implementation of a Convolutional Neural Network
-- for classification of MNIST digits. Sports two hidden layers.
--
-- With the current CPU backend it's slow enough that it's hard to see
-- if it trains.
module MnistCnnRanked2 where

import Prelude

import Data.Vector.Generic qualified as V
import Data.Vector.Storable (Vector)
import GHC.TypeLits (type (*), type (+), type Div)

import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Ranked.Shape

import HordeAd
import MnistData

-- | The differentiable type of all trainable parameters of this nn.
-- Shaped version, statically checking all dimension widths.
--
-- Because subtraction complicates positive-number type inference,
-- @kh@ denotes the kernel height minus one and, analogously, @kw@ denotes
-- the kernel width minus one.
type ADCnnMnistParametersShaped
       (target :: Target) h w kh kw c_out n_hidden r =
  ( ( target (TKS '[c_out, 1, kh + 1, kw + 1] r)
    , target (TKS '[c_out] r) )
  , ( target (TKS '[c_out, c_out, kh + 1, kw + 1] r)
    , target (TKS '[c_out] r) )
  , ( target (TKS '[n_hidden, c_out * (h `Div` 4) * (w `Div` 4)] r)
    , target (TKS '[n_hidden] r) )
  , ( target (TKS '[SizeMnistLabel, n_hidden] r)
    , target (TKS '[SizeMnistLabel] r) )
  )
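
-- As a minimal, purely illustrative sketch of how the shaped parameter type
-- is meant to be instantiated (the synonym below is hypothetical and not used
-- elsewhere in this module): MNIST glyphs are 28x28, so with a 4x4 kernel
-- (hence @kh = kw = 3@), 16 output channels and 64 hidden neurons, the dense
-- weight matrix gets shape @[64, 16 * 7 * 7]@, that is, @[64, 784]@.
type ExampleCnnParamsShaped target r =
  ADCnnMnistParametersShaped target 28 28 3 3 16 64 r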

-- | The differentiable type of all trainable parameters of this nn.
-- Ranked version.
type ADCnnMnistParameters (target :: Target) r =
  ( ( target (TKR 4 r)
    , target (TKR 1 r) )
  , ( target (TKR 4 r)
    , target (TKR 1 r) )
  , ( target (TKR 2 r)
    , target (TKR 1 r) )
  , ( target (TKR 2 r)
    , target (TKR 1 r) ) )

-- | A single convolutional layer with @relu@ and @maxPool@.
convMnistLayerR
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r)  -- ^ @[c_out, c_in, kh + 1, kw + 1]@
  -> target (TKR 4 r)  -- ^ @[batch_size, c_in, h, w]@
  -> target (TKR 1 r)  -- ^ @[c_out]@
  -> target (TKR 4 r)  -- ^ @[batch_size, c_out, h \`Div\` 2, w \`Div\` 2]@
convMnistLayerR ker input bias =
  let (batch_size :$: _ :$: h :$: w :$: ZSR) = rshape input
      yConv = conv2dUnpadded ker input
      biasStretched = rtranspose [0, 3, 1, 2]
                      $ rreplicate batch_size $ rreplicate h $ rreplicate w bias
      yRelu = relu $ yConv + biasStretched
  in maxPool2dUnpadded 2 2 yRelu
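
-- A shape sketch of a single layer, assuming a 28x28 MNIST input and
-- hypothetical widths @c_in = 1@, @c_out = 16@, @batch_size = 8@:
--
-- > input         :: [8,  1, 28, 28]
-- > yConv         :: [8, 16, 28, 28]  -- conv2dUnpadded keeps h and w here,
-- >                                   -- as the bias addition below requires
-- > biasStretched :: [8, 16, 28, 28]  -- bias replicated over batch, h and w
-- > yRelu         :: [8, 16, 28, 28]
-- > result        :: [8, 16, 14, 14]  -- after the 2x2 maxPool2dUnpadded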

-- | Composition of two convolutional layers.
convMnistTwoR
  :: (ADReady target, GoodScalar r, Differentiable r)
  => Int -> Int -> Int
  -> PrimalOf target (TKR 4 r)
       -- ^ input images @[batch_size, 1, SizeMnistHeight, SizeMnistWidth]@
  -> ADCnnMnistParameters target r  -- ^ parameters
  -> target (TKR 2 r)  -- ^ output classification @[SizeMnistLabel, batch_size]@
convMnistTwoR sizeMnistHeightI sizeMnistWidthI batch_size input
              ( (ker1, bias1), (ker2, bias2)
              , (weightsDense, biasesDense), (weightsReadout, biasesReadout) ) =
  let t1 = convMnistLayerR ker1 (rfromPrimal input) bias1
      t2 = convMnistLayerR ker2 t1 bias2
             -- [ batch_size, c_out
             -- , SizeMnistHeight `Div` 4, SizeMnistWidth `Div` 4 ]
      c_out = rwidth bias1
      m1 = rreshape (batch_size
                     :$: c_out * (sizeMnistHeightI `div` 4)
                               * (sizeMnistWidthI `div` 4)
                     :$: ZSR)
                    t2
      m2 = rtr m1
      denseLayer = weightsDense `rmatmul2` m2
                   + rtr (rreplicate batch_size biasesDense)
      denseRelu = relu denseLayer
  in weightsReadout `rmatmul2` denseRelu
     + rtr (rreplicate batch_size biasesReadout)
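
-- A shape sketch of the whole network, assuming 28x28 glyphs, hypothetical
-- widths @c_out = 16@, @n_hidden = 64@, @batch_size = 8@ and the usual
-- @SizeMnistLabel = 10@:
--
-- > input      :: [8, 1, 28, 28]
-- > t1         :: [8, 16, 14, 14]
-- > t2         :: [8, 16, 7, 7]
-- > m1         :: [8, 784]   -- 16 * (28 `div` 4) * (28 `div` 4) = 784
-- > m2         :: [784, 8]
-- > denseLayer :: [64, 8]    -- weightsDense has shape [64, 784]
-- > result     :: [10, 8]    -- weightsReadout has shape [10, 64]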

-- | The neural network composed with the SoftMax-CrossEntropy loss function.
convMnistLossFusedR
  :: (ADReady target, ADReady (PrimalOf target), GoodScalar r, Differentiable r)
  => Int  -- ^ batch_size
  -> ( PrimalOf target (TKR 3 r)
         -- ^ @[batch_size, SizeMnistHeight, SizeMnistWidth]@
     , PrimalOf target (TKR 2 r) )  -- ^ @[batch_size, SizeMnistLabel]@
  -> ADCnnMnistParameters target r  -- kh kw c_out n_hidden
  -> target (TKScalar r)
convMnistLossFusedR batch_size (glyphR, labelR) adparameters =
  let input = rreshape (batch_size
                        :$: 1
                        :$: sizeMnistHeightInt
                        :$: sizeMnistWidthInt
                        :$: ZSR)
                       glyphR
      result = convMnistTwoR sizeMnistHeightInt sizeMnistWidthInt
                             batch_size input adparameters
      targets = rtr labelR
      loss = lossSoftMaxCrossEntropyR targets result
  in kfromPrimal (recip $ kconcrete $ fromIntegral batch_size) * loss
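
-- Note: @kfromPrimal (recip $ kconcrete $ fromIntegral batch_size)@ is, by
-- construction, a constant from the point of view of differentiation, so the
-- value returned by @lossSoftMaxCrossEntropyR@ is merely rescaled by
-- @1 / batch_size@ (e.g., by 1/64 for a batch of 64).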

-- | Tests the neural network on a batch of test data, given the trained
-- parameters, and returns the fraction of correctly classified samples.
convMnistTestR
  :: forall target r.
     (target ~ Concrete, GoodScalar r, Differentiable r)
  => Int
  -> MnistDataBatchR r
  -> ADCnnMnistParameters Concrete r
  -> r
convMnistTestR 0 _ _ = 0
convMnistTestR batch_size (glyphR, labelR) testParams =
  let input :: target (TKR 4 r)
      input =
        rconcrete $ Nested.rreshape [ batch_size, 1
                                    , sizeMnistHeightInt, sizeMnistWidthInt ]
                                    glyphR
      outputR :: Concrete (TKR 2 r)
      outputR =
        let nn :: ADCnnMnistParameters target r
               -> target (TKR 2 r)  -- [SizeMnistLabel, batch_size]
            nn = convMnistTwoR sizeMnistHeightInt sizeMnistWidthInt
                               batch_size input
        in nn testParams
      outputs = map rtoVector $ runravelToList
                $ rtranspose [1, 0] outputR
      labels = map rtoVector
               $ runravelToList @_ @(TKScalar r)
               $ rconcrete labelR
      matchesLabels :: Vector r -> Vector r -> Int
      matchesLabels output label | V.maxIndex output == V.maxIndex label = 1
                                 | otherwise = 0
  in fromIntegral (sum (zipWith matchesLabels outputs labels))
     / fromIntegral batch_size