{-# OPTIONS_GHC -Wno-missing-export-lists #-}
-- | Parsing and pre-processing of MNIST data.
module MnistData where

import Prelude

import Codec.Compression.GZip (decompress)
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as LBS
import Data.IDX
import Data.List (sortBy)
import Data.List.NonEmpty qualified as NonEmpty
import Data.Maybe (fromMaybe)
import Data.Ord (comparing)
import Data.Vector.Generic qualified as V
import Data.Vector.Storable (Vector)
import Data.Vector.Storable qualified as VS
import Data.Vector.Unboxed qualified
import GHC.TypeLits (KnownNat, Nat, type (*))
import System.IO (IOMode (ReadMode), withBinaryFile)
import System.Random

import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Ranked.Shape
import Data.Array.Nested.Shaped.Shape

import HordeAd

-- | Width of an MNIST glyph in pixels (28), as a type-level natural.
type SizeMnistWidth = 28 :: Nat

-- | Singleton witnessing 'SizeMnistWidth'; the index is fixed by the
-- signature, so the bare 'SNat' suffices.
sizeMnistWidth :: SNat SizeMnistWidth
sizeMnistWidth = SNat

-- | 'SizeMnistWidth' demoted to a term-level 'Int'.
sizeMnistWidthInt :: Int
sizeMnistWidthInt = sNatValue (SNat @SizeMnistWidth)

-- | Height of an MNIST glyph in pixels; defined to be equal to
-- 'SizeMnistWidth'.
type SizeMnistHeight = SizeMnistWidth

-- | Singleton witnessing 'SizeMnistHeight'.
sizeMnistHeight :: SNat SizeMnistHeight
sizeMnistHeight = SNat

-- | 'SizeMnistHeight' demoted to a term-level 'Int'.
sizeMnistHeightInt :: Int
sizeMnistHeightInt = sNatValue sizeMnistHeight

-- | Number of pixels in one glyph: width times height.
type SizeMnistGlyph = SizeMnistWidth * SizeMnistHeight

-- | 'SizeMnistGlyph' demoted to a term-level 'Int'; this is the length of
-- a flattened glyph vector.
sizeMnistGlyphInt :: Int
sizeMnistGlyphInt = sNatValue (SNat @SizeMnistGlyph)

-- | Number of label classes, i.e. the length of a one-hot target vector.
type SizeMnistLabel = 10 :: Nat

-- | Singleton witnessing 'SizeMnistLabel'.
sizeMnistLabel :: SNat SizeMnistLabel
sizeMnistLabel = SNat

-- | 'SizeMnistLabel' demoted to a term-level 'Int'.
sizeMnistLabelInt :: Int
sizeMnistLabelInt = sNatValue (SNat @SizeMnistLabel)

-- | Number of test samples as a type-level natural.
-- NOTE(review): presumably the size of the @t10k@ test files; nothing in
-- this module checks that.
type LengthTestData = 10000 :: Nat

-- Actually, a better representation, supported by @Data.IDX@,
-- is an integer label and a picture (the same vector as below).
-- Then we'd use @lossCrossEntropy@ that picks a component according
-- to the label instead of performing a dot product with scaling.
-- This results in much smaller Delta expressions.
-- Our library makes this easy to express and gradients compute fine.
-- On the other hand, methods with only matrix operations and graphs
-- can't handle that. However, the goal of the exercise is to implement
-- the same neural net that backprop uses for comparative benchmarks.
-- Also, loss computation is not the bottleneck and the more general
-- mechanism that admits non-discrete target labels fuses nicely
-- with softMax. This also seems to be the standard, or at least
-- a simple default, in tutorials.
-- | A raw sample as produced by 'readMnistData': the glyph's pixel vector
-- (each pixel divided by 255) paired with a one-hot label vector.
type MnistData r = (Vector r, Vector r)

-- | A sample as two rank-1 arrays: the flattened glyph and the label
-- vector (built by 'mkMnistDataLinearR').
type MnistDataLinearR r =
  ( Nested.Ranked 1 r
  , Nested.Ranked 1 r )

-- | A sample with the glyph as a rank-2 array (laid out as
-- @[SizeMnistHeight, SizeMnistWidth]@ by 'mkMnistDataR') and the label
-- as a rank-1 array.
type MnistDataR r =
  ( Nested.Ranked 2 r
  , Nested.Ranked 1 r )

-- | A batch of 'MnistDataR' samples stacked along a new outer dimension.
type MnistDataBatchR r =
  ( Nested.Ranked 3 r  -- [batch_size, SizeMnistHeight, SizeMnistWidth]
  , Nested.Ranked 2 r )  -- [batch_size, SizeMnistLabel]

-- | A sample whose glyph and label shapes are known at the type level.
type MnistDataS r =
  ( Nested.Shaped '[SizeMnistHeight, SizeMnistWidth] r
  , Nested.Shaped '[SizeMnistLabel] r )

-- | A batch of @batch_size@ samples, with all shapes (including the batch
-- dimension) known at the type level.
type MnistDataBatchS batch_size r =
  ( Nested.Shaped '[batch_size, SizeMnistHeight, SizeMnistWidth] r
  , Nested.Shaped '[batch_size, SizeMnistLabel] r )

-- | Wrap a raw sample as two rank-1 arrays: the flattened glyph, of
-- length 'sizeMnistGlyphInt', and the label vector, of length
-- 'sizeMnistLabelInt'.
mkMnistDataLinearR :: Nested.PrimElt r
                   => MnistData r -> MnistDataLinearR r
mkMnistDataLinearR (input, target) =
  (toRank1 sizeMnistGlyphInt input, toRank1 sizeMnistLabelInt target)
  where
    -- View a flat storable vector as a rank-1 array of the given length.
    toRank1 :: Nested.PrimElt a => Int -> Vector a -> Nested.Ranked 1 a
    toRank1 len = Nested.rfromVector (len :$: ZSR)

-- | Reshape a raw sample: the glyph becomes a rank-2 array of shape
-- @[sizeMnistHeightInt, sizeMnistWidthInt]@ and the label a rank-1 array
-- of length 'sizeMnistLabelInt'.
mkMnistDataR :: Nested.PrimElt r
             => MnistData r -> MnistDataR r
mkMnistDataR (input, target) = (glyphR, labelR)
  where
    glyphR =
      Nested.rfromVector (sizeMnistHeightInt :$: sizeMnistWidthInt :$: ZSR)
                         input
    labelR = Nested.rfromVector (sizeMnistLabelInt :$: ZSR) target

-- | Stack a list of samples into one batch. The glyphs become a rank-3
-- array and the labels a rank-2 array, with the list position as the new
-- outermost dimension (via 'Nested.rfromListOuter').
--
-- An empty list cannot form a batch ('Nested.rfromListOuter' needs a
-- 'NonEmpty'). In that case the result is an 'error' naming this function
-- rather than the generic @NonEmpty.fromList@ failure.
mkMnistDataBatchR :: Nested.Elt r
                  => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR l =
  let (inputs, targets) = unzip l
  in ( Nested.rfromListOuter (nonEmptyBatch inputs)
     , Nested.rfromListOuter (nonEmptyBatch targets) )
  where
    -- Total replacement for 'NonEmpty.fromList' with a descriptive failure.
    nonEmptyBatch :: [a] -> NonEmpty.NonEmpty a
    nonEmptyBatch =
      fromMaybe (error "mkMnistDataBatchR: empty batch") . NonEmpty.nonEmpty

-- | Reshape a raw sample into statically-shaped arrays. The target shapes
-- are taken from 'MnistDataS' through 'knownShS'.
mkMnistDataS :: Nested.PrimElt r
             => MnistData r -> MnistDataS r
mkMnistDataS (input, target) = (glyphS, labelS)
  where
    glyphS = Nested.sfromVector knownShS input
    labelS = Nested.sfromVector knownShS target
{-# SPECIALIZE mkMnistDataS :: MnistData Double -> MnistDataS Double #-}
{-# SPECIALIZE mkMnistDataS :: MnistData Float -> MnistDataS Float #-}

-- | Stack a list of statically-shaped samples into a batch whose outer
-- dimension is the type-level @batch_size@.
--
-- An empty list is reported via 'error' naming this function, instead of
-- the generic @NonEmpty.fromList@ failure.
-- NOTE(review): a list whose length differs from @batch_size@ is presumably
-- rejected inside 'Nested.sfromListOuter'; this function does not check it.
mkMnistDataBatchS :: forall batch_size r. (Nested.Elt r, KnownNat batch_size)
                  => [MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS l =
  let (inputs, targets) = unzip l
  in ( Nested.sfromListOuter (SNat @batch_size) (nonEmptyBatch inputs)
     , Nested.sfromListOuter (SNat @batch_size) (nonEmptyBatch targets) )
  where
    -- Total replacement for 'NonEmpty.fromList' with a descriptive failure.
    nonEmptyBatch :: [a] -> NonEmpty.NonEmpty a
    nonEmptyBatch =
      fromMaybe (error "mkMnistDataBatchS: empty batch") . NonEmpty.nonEmpty
{-# SPECIALIZE mkMnistDataBatchS :: forall batch_size. KnownNat batch_size => [MnistDataS Double] -> MnistDataBatchS batch_size Double #-}
{-# SPECIALIZE mkMnistDataBatchS :: forall batch_size. KnownNat batch_size => [MnistDataS Float] -> MnistDataBatchS batch_size Float #-}

-- | Decode MNIST glyphs and labels from already-decompressed IDX data.
--
-- Each sample pairs the glyph, with every pixel divided by 255, with a
-- label vector of length 'sizeMnistLabelInt'. The label vector holds @1@ at
-- the index equal to the sample's label and @0@ elsewhere. Malformed input
-- makes the list 'error' when it is forced.
readMnistData :: forall r. (VS.Storable r, Fractional r)
              => LBS.ByteString -> LBS.ByteString -> [MnistData r]
readMnistData glyphsBS labelsBS = map toSample samples
  where
    -- Unwrap a decoding result, failing with the given message on 'Nothing'.
    orDie :: String -> Maybe a -> a
    orDie msg = fromMaybe (error msg)

    glyphs = orDie "wrong MNIST glyphs file" (decodeIDX glyphsBS)
    labels = orDie "wrong MNIST labels file" (decodeIDXLabels labelsBS)
    samples = orDie "can't decode MNIST file into integers"
                    (labeledIntData labels glyphs)

    -- Copied from library backprop to enable comparison of results.
    -- NOTE(review): how this differs from @labeledDoubleData@ is unverified.
    toSample :: (Int, Data.Vector.Unboxed.Vector Int) -> MnistData r
    toSample (labN, pixels) =
      let !vGlyph = V.map (\px -> fromIntegral px / 255) (V.convert pixels)
          !vLabel = V.generate sizeMnistLabelInt
                               (\i -> if i == labN then 1 else 0)
      in (vGlyph, vLabel)
{-# SPECIALIZE readMnistData :: LBS.ByteString -> LBS.ByteString -> [MnistData Double] #-}
{-# SPECIALIZE readMnistData :: LBS.ByteString -> LBS.ByteString -> [MnistData Float] #-}

-- | Gzipped IDX file with the training glyphs (path relative to the
-- working directory).
trainGlyphsPath :: FilePath
trainGlyphsPath = "samplesData/train-images-idx3-ubyte.gz"

-- | Gzipped IDX file with the training labels.
trainLabelsPath :: FilePath
trainLabelsPath = "samplesData/train-labels-idx1-ubyte.gz"

-- | Gzipped IDX file with the test glyphs.
testGlyphsPath :: FilePath
testGlyphsPath = "samplesData/t10k-images-idx3-ubyte.gz"

-- | Gzipped IDX file with the test labels.
testLabelsPath :: FilePath
testLabelsPath = "samplesData/t10k-labels-idx1-ubyte.gz"

-- | Read a gzipped IDX glyphs file and a gzipped IDX labels file, then
-- decode them with 'readMnistData'.
--
-- Each file is read strictly into memory while its handle is still open.
-- Decompression and decoding therefore never depend on lazy reads from a
-- handle that 'withBinaryFile' has already closed. The resulting list is
-- forced to weak head normal form before it is returned.
loadMnistData :: (VS.Storable r, Fractional r)
              => FilePath -> FilePath -> IO [MnistData r]
loadMnistData glyphsPath labelsPath = do
  glyphsContents <- readWhole glyphsPath
  labelsContents <- readWhole labelsPath
  return $! readMnistData (decompress glyphsContents)
                          (decompress labelsContents)
  where
    -- Strictly read an entire file. 'BS.hGetContents' consumes the handle
    -- completely before 'withBinaryFile' closes it.
    readWhole :: FilePath -> IO LBS.ByteString
    readWhole path =
      withBinaryFile path ReadMode (fmap LBS.fromStrict . BS.hGetContents)
{-# SPECIALIZE loadMnistData :: FilePath -> FilePath -> IO [MnistData Double] #-}
{-# SPECIALIZE loadMnistData :: FilePath -> FilePath -> IO [MnistData Float] #-}

-- | Pseudo-randomly permute a list. Each element is tagged with an 'Int'
-- drawn via 'randoms' from the generator, the pairs are stably sorted by
-- tag, and the tags are dropped. Before the list is returned, its spine
-- and every element are forced to weak head normal form.
--
-- Good enough for QuickCheck, so good enough for me.
shuffle :: StdGen -> [a] -> [a]
shuffle g l = forceElems shuffled
  where
    sortKeys :: [Int]
    sortKeys = randoms g
    shuffled = [x | (x, _) <- sortBy (comparing snd) (zip l sortKeys)]
    -- Evaluate the spine and each element before handing the list back.
    forceElems xs = foldr seq () xs `seq` xs

-- | Split a list into consecutive pieces of @n@ elements; the last piece
-- may be shorter.
--
-- For @n <= 0@ and a non-empty list, 'splitAt' makes no progress, so the
-- result is an infinite list of empty lists.
chunksOf :: Int -> [e] -> [[e]]
chunksOf n xs = case xs of
  [] -> []
  _ ->
    let (chunk, rest) = splitAt n xs
    in chunk : chunksOf n rest