{-# OPTIONS_GHC -Wno-missing-export-lists #-}
module MnistData where
import Prelude
import Codec.Compression.GZip (decompress)
import Data.ByteString.Lazy qualified as LBS
import Data.IDX
import Data.List (sortBy)
import Data.List.NonEmpty qualified as NonEmpty
import Data.Maybe (fromMaybe)
import Data.Ord (comparing)
import Data.Vector.Generic qualified as V
import Data.Vector.Storable (Vector)
import Data.Vector.Storable qualified as VS
import Data.Vector.Unboxed qualified
import GHC.TypeLits (KnownNat, Nat, type (*))
import System.IO (IOMode (ReadMode), withBinaryFile)
import System.Random
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Ranked.Shape
import Data.Array.Nested.Shaped.Shape
import HordeAd
-- | Width of an MNIST glyph, in pixels.
type SizeMnistWidth = 28 :: Nat

-- | Singleton witness for 'SizeMnistWidth'; the type signature fixes the
-- index, so no explicit type application is needed.
sizeMnistWidth :: SNat SizeMnistWidth
sizeMnistWidth = SNat

-- | 'SizeMnistWidth' demoted to a term-level 'Int'.
sizeMnistWidthInt :: Int
sizeMnistWidthInt = sNatValue sizeMnistWidth
-- | Height of an MNIST glyph, in pixels. It is defined as equal to
-- 'SizeMnistWidth', so glyphs are square.
type SizeMnistHeight = SizeMnistWidth

-- | Singleton witness for 'SizeMnistHeight'.
sizeMnistHeight :: SNat SizeMnistHeight
sizeMnistHeight = SNat

-- | 'SizeMnistHeight' demoted to a term-level 'Int'.
sizeMnistHeightInt :: Int
sizeMnistHeightInt = valueOf @SizeMnistHeight
-- | Number of pixels in one glyph: 'SizeMnistWidth' times 'SizeMnistHeight'.
type SizeMnistGlyph = SizeMnistWidth * SizeMnistHeight

-- | 'SizeMnistGlyph' demoted to a term-level 'Int'; this is the length of a
-- flattened glyph vector.
sizeMnistGlyphInt :: Int
sizeMnistGlyphInt = valueOf @SizeMnistGlyph
-- | Number of distinct labels. Each label is encoded as a vector of this
-- length.
type SizeMnistLabel = 10 :: Nat

-- | Singleton witness for 'SizeMnistLabel'.
sizeMnistLabel :: SNat SizeMnistLabel
sizeMnistLabel = SNat

-- | 'SizeMnistLabel' demoted to a term-level 'Int'.
sizeMnistLabelInt :: Int
sizeMnistLabelInt = sNatValue sizeMnistLabel
-- | Type-level constant 10000, named for the length of the test data.
type LengthTestData = 10000 :: Nat

-- | One raw sample: a flat vector of glyph pixels paired with a label vector.
type MnistData r = (Vector r, Vector r)

-- | One sample with both the glyph and the label as rank-1 arrays.
type MnistDataLinearR r = (Nested.Ranked 1 r, Nested.Ranked 1 r)

-- | One sample with the glyph as a rank-2 array and the label as a rank-1
-- array.
type MnistDataR r = (Nested.Ranked 2 r, Nested.Ranked 1 r)

-- | A batch of 'MnistDataR' samples, each component gaining one outer
-- dimension.
type MnistDataBatchR r = (Nested.Ranked 3 r, Nested.Ranked 2 r)

-- | One sample as shaped arrays whose dimensions are fixed at the type level.
type MnistDataS r =
  (Nested.Shaped '[SizeMnistHeight, SizeMnistWidth] r, Nested.Shaped '[SizeMnistLabel] r)

-- | A batch of @batch_size@ 'MnistDataS' samples.
type MnistDataBatchS batch_size r =
  ( Nested.Shaped '[batch_size, SizeMnistHeight, SizeMnistWidth] r
  , Nested.Shaped '[batch_size, SizeMnistLabel] r )
-- | Wraps a raw sample as two rank-1 arrays:
--
-- * the glyph vector, with shape @[sizeMnistGlyphInt]@;
-- * the label vector, with shape @[sizeMnistLabelInt]@.
--
-- NOTE(review): presumably 'Nested.rfromVector' rejects a vector whose
-- length does not match the shape — confirm against ox-arrays.
mkMnistDataLinearR :: Nested.PrimElt r
                   => MnistData r -> MnistDataLinearR r
mkMnistDataLinearR (pixels, oneHot) = (glyphArr, labelArr)
  where
    glyphArr = Nested.rfromVector (sizeMnistGlyphInt :$: ZSR) pixels
    labelArr = Nested.rfromVector (sizeMnistLabelInt :$: ZSR) oneHot
-- | Wraps a raw sample as ranked arrays:
--
-- * the glyph is reshaped to @[sizeMnistHeightInt, sizeMnistWidthInt]@
--   (height outermost);
-- * the label stays a rank-1 array of @sizeMnistLabelInt@ entries.
mkMnistDataR :: Nested.PrimElt r
             => MnistData r -> MnistDataR r
mkMnistDataR (pixels, oneHot) = (glyphArr, labelArr)
  where
    glyphArr =
      Nested.rfromVector (sizeMnistHeightInt :$: sizeMnistWidthInt :$: ZSR)
                         pixels
    labelArr = Nested.rfromVector (sizeMnistLabelInt :$: ZSR) oneHot
-- | Stacks a list of ranked samples along a new outermost dimension. All
-- glyphs form one array and all labels form another.
--
-- Fails (via 'NonEmpty.fromList') when the list is empty.
mkMnistDataBatchR :: Nested.Elt r
                  => [MnistDataR r] -> MnistDataBatchR r
mkMnistDataBatchR l =
  ( Nested.rfromListOuter (NonEmpty.fromList (map fst l))
  , Nested.rfromListOuter (NonEmpty.fromList (map snd l)) )
-- | Wraps a raw sample as shaped arrays. The shapes come from the result
-- type 'MnistDataS' through 'knownShS':
--
-- * the glyph has shape @[SizeMnistHeight, SizeMnistWidth]@;
-- * the label has shape @[SizeMnistLabel]@.
mkMnistDataS :: Nested.PrimElt r
             => MnistData r -> MnistDataS r
mkMnistDataS (pixels, oneHot) = (glyphArr, labelArr)
  where
    glyphArr = Nested.sfromVector knownShS pixels
    labelArr = Nested.sfromVector knownShS oneHot
{-# SPECIALIZE mkMnistDataS :: MnistData Double -> MnistDataS Double #-}
{-# SPECIALIZE mkMnistDataS :: MnistData Float -> MnistDataS Float #-}
-- | Stacks a list of shaped samples along a new outermost dimension of
-- static size @batch_size@.
--
-- Fails (via 'NonEmpty.fromList') on an empty list.
-- NOTE(review): presumably 'Nested.sfromListOuter' also rejects a list
-- whose length differs from @batch_size@ — confirm.
mkMnistDataBatchS :: forall batch_size r. (Nested.Elt r, KnownNat batch_size)
                  => [MnistDataS r] -> MnistDataBatchS batch_size r
mkMnistDataBatchS l =
  ( Nested.sfromListOuter batchSNat (NonEmpty.fromList (map fst l))
  , Nested.sfromListOuter batchSNat (NonEmpty.fromList (map snd l)) )
  where
    batchSNat = SNat @batch_size
{-# SPECIALIZE mkMnistDataBatchS :: forall batch_size. KnownNat batch_size => [MnistDataS Double] -> MnistDataBatchS batch_size Double #-}
{-# SPECIALIZE mkMnistDataBatchS :: forall batch_size. KnownNat batch_size => [MnistDataS Float] -> MnistDataBatchS batch_size Float #-}
-- | Decodes already-decompressed IDX glyph and label bytestrings into
-- samples.
--
-- * Each pixel value is converted with 'fromIntegral' and divided by 255.
-- * Each label @n@ becomes a vector of 'sizeMnistLabelInt' entries, with 1
--   at index @n@ and 0 everywhere else.
--
-- Calls 'error' if the glyphs or the labels fail to decode, or if
-- 'labeledIntData' cannot pair them.
readMnistData :: forall r. (VS.Storable r, Fractional r)
              => LBS.ByteString -> LBS.ByteString -> [MnistData r]
readMnistData glyphsBS labelsBS = map toSample labelled
  where
    -- Unwraps a decoding result, failing with the given message on Nothing.
    orDie :: String -> Maybe b -> b
    orDie msg = fromMaybe (error msg)

    glyphs = orDie "wrong MNIST glyphs file" (decodeIDX glyphsBS)
    labels = orDie "wrong MNIST labels file" (decodeIDXLabels labelsBS)
    labelled = orDie "can't decode MNIST file into integers"
                     (labeledIntData labels glyphs)

    -- Converts one (label, pixels) pair; both vectors are forced as soon as
    -- the pair itself is demanded.
    toSample :: (Int, Data.Vector.Unboxed.Vector Int) -> MnistData r
    toSample (labelIx, pixels) =
      let !glyphVec = V.map (\px -> fromIntegral px / 255) (V.convert pixels)
          !labelVec = V.generate sizeMnistLabelInt
                                 (\i -> if i == labelIx then 1 else 0)
      in (glyphVec, labelVec)
{-# SPECIALIZE readMnistData :: LBS.ByteString -> LBS.ByteString -> [MnistData Double] #-}
{-# SPECIALIZE readMnistData :: LBS.ByteString -> LBS.ByteString -> [MnistData Float] #-}
-- | Locations of the gzip-compressed MNIST IDX files. They are relative
-- paths, so they resolve against the current working directory.
trainGlyphsPath, trainLabelsPath, testGlyphsPath, testLabelsPath :: FilePath
trainGlyphsPath = "samplesData/train-images-idx3-ubyte.gz"
trainLabelsPath = "samplesData/train-labels-idx1-ubyte.gz"
testGlyphsPath  = "samplesData/t10k-images-idx3-ubyte.gz"
testLabelsPath  = "samplesData/t10k-labels-idx1-ubyte.gz"
-- | Reads the gzip-compressed IDX glyph and label files at the given paths,
-- decompresses them, and decodes them with 'readMnistData'.
--
-- Both files are read with lazy I/O, and 'withBinaryFile' closes each handle
-- as soon as its callback returns. The raw bytes are therefore forced in full
-- inside the callbacks. Without this, a complete result would depend on the
-- IDX decoder happening to consume all input before the handles close.
loadMnistData :: (VS.Storable r, Fractional r)
              => FilePath -> FilePath -> IO [MnistData r]
loadMnistData glyphsPath labelsPath =
  withBinaryFile glyphsPath ReadMode $ \glyphsHandle ->
    withBinaryFile labelsPath ReadMode $ \labelsHandle -> do
      glyphsContents <- LBS.hGetContents glyphsHandle
      labelsContents <- LBS.hGetContents labelsHandle
      -- Traversing both lazy bytestrings pulls every chunk from the handles
      -- while they are still open.
      let !_ = LBS.length glyphsContents + LBS.length labelsContents
      return $! readMnistData (decompress glyphsContents)
                              (decompress labelsContents)
{-# SPECIALIZE loadMnistData :: FilePath -> FilePath -> IO [MnistData Double] #-}
{-# SPECIALIZE loadMnistData :: FilePath -> FilePath -> IO [MnistData Float] #-}
-- | Reorders a list pseudo-randomly using the given generator.
--
-- * Each element is tagged with the next 'Int' from @randoms gen@.
-- * The tagged list is sorted by tag with 'sortBy'; the sort is stable, so
--   equal tags keep their original relative order.
-- * Demanding the result forces its whole spine and every element to WHNF.
shuffle :: StdGen -> [a] -> [a]
shuffle gen xs = forceAll permuted
  where
    keys :: [Int]
    keys = randoms gen
    permuted = [x | (x, _) <- sortBy (comparing snd) (zip xs keys)]
    forceAll ys = foldr seq () ys `seq` ys
-- | Splits a list into consecutive chunks of @n@ elements; the last chunk
-- may be shorter. An empty input yields no chunks.
--
-- If @n <= 0@ and the input is non-empty, every 'splitAt' returns an empty
-- prefix, so the result is an infinite list of empty chunks.
chunksOf :: Int -> [e] -> [[e]]
chunksOf n xs
  | null xs = []
  | otherwise =
      let (front, back) = splitAt n xs
      in front : chunksOf n back