{- | Implements batching - technology that allows packing and processing multiple samples at once.
-}


-- 'TypeFamilies' are needed to instantiate 'DType'.

{-# LANGUAGE TypeFamilies #-}


module Synapse.NN.Batching
    ( -- * 'Sample' datatype

    
      Sample (Sample, sampleInput, sampleOutput)

      -- * 'Dataset' datatype


    , Dataset (Dataset, unDataset)

    , datasetSize

    , shuffleDataset
    , splitDataset

    , VecDataset
    , BatchedDataset
    , batchVectors
    ) where


import Synapse.Tensors (DType, Indexable(unsafeIndex))

import Synapse.Tensors.Vec (Vec(Vec))
import qualified Synapse.Tensors.Vec as V

import Synapse.Tensors.Mat (Mat)
import qualified Synapse.Tensors.Mat as M

import Control.Monad.ST (runST)

import System.Random (RandomGen, uniformR)

import Data.Vector (thaw, unsafeFreeze)
import Data.Vector.Mutable (swap)


-- | 'Sample' datatype represents a known pair of input and output of a function that is unknown.
data Sample a = Sample
    { sampleInput  :: a  -- ^ Sample input.
    , sampleOutput :: a  -- ^ Sample output.
    } deriving (Eq, Show)

-- The 'DType' of a 'Sample' is the 'DType' of the type it wraps.
type instance DType (Sample a) = DType a


-- | 'Dataset' newtype wraps a vector of 'Sample's - it represents the known information about an unknown function.
newtype Dataset a = Dataset
    { unDataset :: Vec (Sample a)  -- ^ Unwraps 'Dataset' newtype.
    } deriving (Eq, Show)

-- The 'DType' of a 'Dataset' is the 'DType' of the type its samples wrap.
type instance DType (Dataset a) = DType a

-- | Returns size of dataset, that is the size of the wrapped vector of 'Sample's.
datasetSize :: Dataset a -> Int
datasetSize = V.size . unDataset


-- | Shuffles any 'Dataset' using Fisher-Yates algorithm.
--
-- Returns the shuffled dataset together with the generator that is left after all random draws,
-- so that the generator can be threaded through further random computations.
-- A dataset with at most one sample is returned as is, and the generator is returned untouched.
shuffleDataset :: RandomGen g => Dataset a -> g -> (Dataset a, g)
shuffleDataset (Dataset dataset) gen
    | V.size dataset <= 1 = (Dataset dataset, gen)
    | otherwise           = runST $ do
        -- 'thaw' yields a mutable copy, so the vector of the given dataset is never mutated.
        mutableVector <- thaw $ V.unVec dataset
        gen' <- go mutableVector (V.size dataset - 1) gen
        -- Freezing without a copy is fine: @mutableVector@ is not written to after this point.
        shuffledVector <- unsafeFreeze mutableVector
        return (Dataset $ Vec shuffledVector, gen')
  where
    -- Walks from the last index down to 1, swapping the element at the current index
    -- with the element at an index drawn from the inclusive range @(0, lastIndex)@.
    go _ 0 seed = return seed
    go v lastIndex seed =
        let (swapIndex, seed') = uniformR (0, lastIndex) seed
        in swap v swapIndex lastIndex >> go v (lastIndex - 1) seed'

-- | Splits dataset in two at the index @round (size * ratio)@.
--
-- The left dataset receives the samples before that index and the right dataset receives the rest,
-- so the size of the left dataset divided by the size of the whole dataset is approximately equal to given ratio
-- (a ratio of @0.8@ puts about 80% of the samples into the left dataset).
splitDataset :: Dataset a -> Float -> (Dataset a, Dataset a)
splitDataset (Dataset dataset) ratio = (Dataset left, Dataset right)
  where
    splitIndex    = round $ fromIntegral (V.size dataset) * ratio
    (left, right) = V.splitAt splitIndex dataset


-- | 'VecDataset' type alias represents 'Dataset's of vector functions: input and output of every 'Sample' are 'Vec's.

type VecDataset a = Dataset (Vec a)
-- | 'BatchedDataset' type alias represents 'Dataset's of vector functions where multiple samples were batched together:
-- input and output of every 'Sample' are 'Mat's.

type BatchedDataset a = Dataset (Mat a)

-- | Batches 'VecDataset' by grouping a given amount of samples into batches.
--
-- Consecutive samples are grouped together, and every batch becomes one 'Sample' of two matrices:
-- the inputs of the grouped samples form the batch input and their outputs form the batch output,
-- one grouped sample per first matrix index. The last batch may hold fewer samples than the given amount.
--
-- Sizes of input and output are taken from the first sample of every batch,
-- so all samples are expected to have inputs of one size and outputs of one size.
-- Sizes of input and output are allowed to differ from each other.
--
-- An empty dataset produces an empty batched dataset.
-- Calls 'error' if given batch size is not positive.
batchVectors :: Int -> VecDataset a -> BatchedDataset a
batchVectors batchSize (Dataset dataset)
    | batchSize <= 0 = error "Synapse.NN.Batching.batchVectors: batch size must be positive"
    | otherwise      = Dataset $ V.fromList $ map groupBatch $ split dataset
  where
    -- Cuts the vector into consecutive pieces of at most @batchSize@ samples.
    split vector
        | V.size vector == 0         = []  -- No samples left - a batch without rows is never produced.
        | V.size vector <= batchSize = [vector]
        | otherwise                  =
            let (current, remainder) = V.splitAt batchSize vector
            in current : split remainder

    -- Builds one matrix that holds inputs and outputs side by side and then cuts it at column @inputCols@.
    groupBatch vector =
        let firstSample = unsafeIndex vector 0  -- In bounds: 'split' never yields an empty piece.
            rows        = V.size vector
            inputCols   = V.size $ sampleInput firstSample
            outputCols  = V.size $ sampleOutput firstSample
            -- Output columns follow input columns, so their index is shifted back by @inputCols@.
            group (r, c)
                | c < inputCols = unsafeIndex (sampleInput (unsafeIndex vector r)) c
                | otherwise     = unsafeIndex (sampleOutput (unsafeIndex vector r)) (c - inputCols)
            fullBatch = M.generate (rows, inputCols + outputCols) group
            (batchInput, batchOutput, _, _) = M.split fullBatch (rows, inputCols)
        in Sample batchInput batchOutput