{- | Provides an interface for creating and using neural network models.

The main model type is 'SequentialModel', which chains layers linearly.
-}


-- 'TypeFamilies' is needed to define the 'Synapse.Tensors.DType' instance for 'SequentialModel'.

{-# LANGUAGE TypeFamilies #-}


module Synapse.NN.Models
    ( -- * Common for models

      InputSize (InputSize)

      -- * 'SequentialModel' datatype

      
    , SequentialModel (SequentialModel, unSequentialModel)
    , buildSequentialModel

    , layerPrefix
    ) where


import Synapse.Tensors (DType, SingletonOps(singleton))

import Synapse.Autograd (SymbolIdentifier (SymbolIdentifier))

import Synapse.NN.Layers.Layer(AbstractLayer(..), Layer, LayerConfiguration)

import Data.Maybe (fromMaybe)
import Data.Foldable (foldl')


-- | 'InputSize' wraps an 'Int': the number of input features the model
-- should support. For example, @InputSize 3@ means that the model supports
-- any matrix of size @(x, 3)@.
--
-- 'buildSequentialModel' hands this value to the first layer configuration.

newtype InputSize = InputSize Int


-- | 'SequentialModel' groups layers linearly: the output of each layer is
-- fed as the input of the next one, in list order.

newtype SequentialModel a = SequentialModel
    { -- | Layers of the model, ordered from the input side to the output side.
      unSequentialModel :: [Layer a]
    }

-- | Renders every layer on its own line, numbered from 1, together with the
-- parameters it exposes under the prefix produced by 'layerPrefix'.
instance Show a => Show (SequentialModel a) where
    show (SequentialModel layers) = concat (zipWith describeLayer [1 ..] layers)
      where
        -- "Layer <n> parameters: <parameters>;\n" for the n-th layer.
        describeLayer n layer =
            "Layer " ++ show n ++ " parameters: "
                ++ show (getParameters (layerPrefix mempty n) layer)
                ++ ";\n"

-- | Builds a 'SequentialModel' from an input size and a list of layer
-- configurations, so that consecutive layers agree on their sizes.
--
-- The first configuration receives the model's 'InputSize'. Every following
-- configuration receives the 'outputSize' of the layer built just before it.
-- When that layer reports no output size ('Nothing'), the size it received
-- is passed on unchanged.

buildSequentialModel :: InputSize -> [LayerConfiguration (Layer a)] -> SequentialModel a
buildSequentialModel (InputSize initialSize) configs = SequentialModel (chain initialSize configs)
  where
    -- Instantiates each configuration with the size flowing out of its predecessor.
    chain _ [] = []
    chain size (configure : rest) =
        let layer    = configure size
            nextSize = fromMaybe size (outputSize layer)
        in  layer : chain nextSize rest

-- | The 'DType' of a @SequentialModel a@ is its type parameter @a@, the same
-- parameter that indexes each of its 'Layer's.
type instance DType (SequentialModel a) = a


-- | Forms the symbol prefix of the @n@-th layer of a model, as required by
-- 'Synapse.NN.Layers.Layer.AbstractLayer'. The result is the given base
-- prefix followed by @"l"@, the layer number and @"w"@.

layerPrefix :: SymbolIdentifier -> Int -> SymbolIdentifier
layerPrefix base n = mconcat (base : map SymbolIdentifier ["l", show n, "w"])

-- | A 'SequentialModel' behaves as a single layer composed of its layers.
-- Layer @i@ (counting from 1) is addressed with @'layerPrefix' prefix i@.
instance AbstractLayer SequentialModel where
    -- The model accepts whatever its first layer accepts. An empty model
    -- reports 'Nothing' instead of failing on a partial 'head'.
    inputSize (SequentialModel layers) = case layers of
        []          -> Nothing
        (first : _) -> inputSize first

    -- The model produces whatever its last layer produces. This was
    -- previously read from the first layer, which is wrong for any model whose
    -- layers change the size. A layer reporting 'Nothing' keeps the size it
    -- received, mirroring 'fromMaybe prevSize' in 'buildSequentialModel'.
    -- So the last known output size wins. An empty model, or one with no
    -- known output sizes, reports 'Nothing'.
    outputSize =
        foldl' (\known layer -> maybe known Just (outputSize layer)) Nothing
            . unSequentialModel

    -- Total number of parameters over all layers.
    nParameters = foldl' (\total layer -> total + nParameters layer) 0 . unSequentialModel

    -- Parameters of all layers in layer order, each layer under its own prefix.
    -- Uses 'concat' instead of repeated left-nested '++', which was quadratic
    -- in the number of layers.
    getParameters prefix =
        concat
            . zipWith (\i layer -> getParameters (layerPrefix prefix i) layer) [1 ..]
            . unSequentialModel

    -- Hands each layer, in order, the next 'nParameters layer' matrices of the
    -- flat parameter list. Surplus matrices are ignored; when the list runs
    -- short, later layers receive fewer (possibly zero) matrices.
    updateParameters (SequentialModel model) = SequentialModel . go model
      where
        go [] _ = []
        go (layer : layers) parameters =
            let (own, rest) = splitAt (nParameters layer) parameters
            in  updateParameters layer own : go layers rest

    -- Threads the input through every layer in order. The result is the final
    -- output paired with the sum of all per-layer losses, starting from
    -- 'singleton 0'.
    symbolicForward prefix input =
        snd . foldl' step (1, (input, singleton 0)) . unSequentialModel
      where
        step (i, (mat, loss)) layer =
            let (mat', newLoss) = symbolicForward (layerPrefix prefix i) mat layer
            in  (i + 1, (mat', loss + newLoss))