{-# LANGUAGE TypeFamilies #-}
module Synapse.NN.Batching
(
Sample (Sample, sampleInput, sampleOutput)
, Dataset (Dataset, unDataset)
, datasetSize
, shuffleDataset
, splitDataset
, VecDataset
, BatchedDataset
, batchVectors
) where
import Synapse.Tensors (DType, Indexable(unsafeIndex))
import Synapse.Tensors.Vec (Vec(Vec))
import qualified Synapse.Tensors.Vec as V
import Synapse.Tensors.Mat (Mat)
import qualified Synapse.Tensors.Mat as M
import Control.Monad.ST (runST)
import System.Random (RandomGen, uniformR)
import Data.Vector (thaw, unsafeFreeze)
import Data.Vector.Mutable (swap)
-- | A pair of an input and an output value that share the same type.
data Sample a = Sample
    { sampleInput  :: a  -- ^ Input part of the sample.
    , sampleOutput :: a  -- ^ Output part of the sample.
    } deriving (Eq, Show)

-- | The 'DType' of a 'Sample' is the 'DType' of the value it wraps.
type instance DType (Sample a) = DType a
-- | A collection of samples, stored as a 'Vec' of 'Sample's.
newtype Dataset a = Dataset
    { unDataset :: Vec (Sample a)  -- ^ The underlying vector of samples.
    } deriving (Eq, Show)

-- | The 'DType' of a 'Dataset' is the 'DType' of the values held by its samples.
type instance DType (Dataset a) = DType a
-- | Number of samples held by the dataset.
datasetSize :: Dataset a -> Int
datasetSize (Dataset samples) = V.size samples
-- | Randomly permutes the samples of a dataset.
--
-- Returns the shuffled dataset together with the generator left over after
-- all random draws. A dataset with at most one sample is returned as is and
-- the generator is handed back untouched.
shuffleDataset :: RandomGen g => Dataset a -> g -> (Dataset a, g)
shuffleDataset original@(Dataset samples) gen
    | count <= 1 = (original, gen)
    | otherwise  = runST $ do
        cells <- thaw (V.unVec samples)
        gen' <- shuffleDown cells (count - 1) gen
        -- The mutable copy is not used again after this point, so freezing
        -- it without copying is fine.
        frozen <- unsafeFreeze cells
        return (Dataset (Vec frozen), gen')
  where
    count = V.size samples

    -- Walks from the last index down to 1; at every index @hi@ the element is
    -- swapped with one chosen uniformly from @[0, hi]@. Returns the generator
    -- that remains after the final draw.
    --
    -- This helper deliberately mentions no variable of the enclosing function,
    -- so that it stays polymorphic in the monad and can run inside 'runST'.
    shuffleDown cells hi rng
        | hi == 0   = return rng
        | otherwise = do
            let (pick, rng') = uniformR (0, hi) rng
            swap cells pick hi
            shuffleDown cells (hi - 1) rng'
-- | Splits a dataset into two parts at a position derived from a ratio.
--
-- The split position is @round (size * ratio)@, where @size@ is the number of
-- samples; the underlying vector is then divided with 'V.splitAt' at that
-- position. The ratio itself is not validated by this function.
splitDataset :: Dataset a -> Float -> (Dataset a, Dataset a)
splitDataset (Dataset samples) ratio = (Dataset front, Dataset back)
  where
    cut :: Int
    cut = round (fromIntegral (V.size samples) * ratio)

    (front, back) = V.splitAt cut samples
-- | A dataset whose samples hold 'Vec' inputs and outputs.
type VecDataset a = Dataset (Vec a)
-- | A dataset whose samples hold 'Mat' inputs and outputs, as produced by 'batchVectors'.
type BatchedDataset a = Dataset (Mat a)
-- | Groups a dataset of vector samples into a dataset of matrix samples.
--
-- The samples are cut into consecutive chunks of @batchSize@ samples; the
-- last chunk holds whatever is left and may be smaller. Each chunk becomes
-- one 'Sample' whose input (output) matrix is filled row by row from the
-- input (output) vectors of the chunk's samples.
--
-- * An empty dataset produces an empty batched dataset.
-- * A non-positive @batchSize@ is a programming error and is reported with
--   'error' (it previously made the function loop forever).
--
-- The column counts of a batch are read from its first sample. All samples of
-- a chunk are assumed to have vectors of those lengths; this is not checked.
batchVectors :: Int -> VecDataset a -> BatchedDataset a
batchVectors batchSize (Dataset dataset)
    | batchSize <= 0      = error "Synapse.NN.Batching.batchVectors: batch size must be positive"
    | V.size dataset == 0 = Dataset $ V.fromList []
    | otherwise           = Dataset $ V.fromList $ map groupBatch $ split dataset
  where
    -- Cuts the vector into chunks of at most @batchSize@ elements.
    split vector
        | V.size vector <= batchSize = [vector]
        | otherwise =
            let (current, remainder) = V.splitAt batchSize vector
            in current : split remainder

    -- Turns one non-empty chunk of samples into a single matrix sample.
    groupBatch vector = Sample batchInput batchOutput
      where
        firstSample = unsafeIndex vector 0
        rows        = V.size vector
        inputCols   = V.size (sampleInput firstSample)
        outputCols  = V.size (sampleOutput firstSample)

        -- Element of the combined (input ++ output) row @r@ at column @c@.
        -- Output columns are addressed relative to the end of the input part;
        -- taking the column modulo @inputCols@ would read the wrong element
        -- whenever the output is wider than the input, and would divide by
        -- zero for empty inputs.
        element (r, c)
            | c < inputCols = unsafeIndex (sampleInput sample) c
            | otherwise     = unsafeIndex (sampleOutput sample) (c - inputCols)
          where
            sample = unsafeIndex vector r

        -- Inputs and outputs are laid out side by side in one matrix, which
        -- is then split at column @inputCols@.
        fullBatch = M.generate (rows, inputCols + outputCols) element

        (batchInput, batchOutput, _, _) = M.split fullBatch (rows, inputCols)