{-# LANGUAGE Arrows              #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}
{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS -fplugin=Overloaded -fplugin-opt=Overloaded:Categories #-}
module Main where

import Control.Monad      (when)
import Data.Word          (Word64)
import Numeric            (showFFloat)
import System.Environment (getArgs)
import Data.List (intercalate)

import qualified Control.Category
import qualified Numeric.LinearAlgebra  as LA
import qualified System.Random.SplitMix as SM

import Overloaded.Categories
import VectorSpace

-- | A differentiable function from @a@ to @b@.
--
-- Applying the wrapped function to a point yields the value at that point
-- paired with the derivative at that point, represented as an @L a b@
-- (imported from "VectorSpace").
newtype AD a b = AD (a -> (b, L a b))

-- | Identity and composition of differentiable functions.
instance Category AD where
    -- The identity returns its argument unchanged and leaves the derivative
    -- representation untouched.
    id = AD $ \point -> (point, L id)

    -- Chain rule: run the inner function, feed its value to the outer one,
    -- and compose the two derivative parts in the same order.
    AD outer . AD inner = AD $ \point ->
        let (middle, L dInner) = inner point
            (result, L dOuter) = outer middle
        in (result, L (dOuter . dInner))

-- | The terminal object of 'AD' is the unit type.
instance CategoryWith1 AD where
    type Terminal AD = ()

    -- Ignore the input entirely; the derivative part is the terminal arrow
    -- of the underlying @L@ category.
    terminal = AD $ \_ -> ((), terminal)

-- | Products in 'AD' are ordinary Haskell pairs.
instance CartesianCategory AD where
    type Product AD = (,)

    -- Projections: the value is the selected component, the derivative is
    -- the matching projection on the @L@ side.
    proj1 = AD $ \pair -> (fst pair, proj1)
    proj2 = AD $ \pair -> (snd pair, proj2)

    -- Run both functions on the same input and pair up both the values and
    -- the derivatives.
    fanout (AD left) (AD right) = AD $ \point ->
        let (leftValue,  dLeft)  = left point
            (rightValue, dRight) = right point
        in ((leftValue, rightValue), fanout dLeft dRight)

-- | Constants embedded as differentiable functions.
instance GeneralizedElement AD where
    type Object AD a = a

    -- A constant function: the value never depends on the input and the
    -- derivative part always answers with 'LZ'.
    konst c = AD (const (c, L (const LZ)))

-- | Derivative part of 'plus': turns a 'LinMap' producing a pair into one
-- producing a single component.
--
-- An @LH f g@ node is replaced by @LA f g@; 'LV', 'LA' and 'LK' nodes are
-- traversed recursively; 'LZ' is preserved and 'LI' becomes @LV LI LI@.
ladd :: LinMap r (a, a) -> LinMap r a
ladd m = case m of
    LZ     -> LZ
    LI     -> LV LI LI
    LH f g -> LA f g
    LK k f -> LK k (ladd f)
    LA f g -> LA (ladd f) (ladd g)
    LV f g -> LV (ladd f) (ladd g)

-- | Derivative part of a product-like operation at the point @(x, y)@.
--
-- An @LH f g@ node is replaced by @LA (LK y f) (LK x g)@, i.e. the first
-- half is scaled by @y@ and the second by @x@. 'LV', 'LA' and 'LK' nodes are
-- traversed recursively; 'LZ' is preserved and 'LI' becomes
-- @LV (LK y LI) (LK x LI)@.
lmult :: Double -> Double -> LinMap r (a, a) -> LinMap r a
lmult x y m = case m of
    LZ     -> LZ
    LI     -> LV (LK y LI) (LK x LI)
    LH f g -> LA (LK y f) (LK x g)
    LK k f -> LK k (lmult x y f)
    LA f g -> LA (lmult x y f) (lmult x y g)
    -- The two halves of an 'LV' have different source types, so the
    -- recursive calls are spelled out rather than shared via a local alias.
    LV f g -> LV (lmult x y f) (lmult x y g)

-- | Addition of two numbers; the derivative part is 'ladd'.
plus :: AD (Double, Double) Double
plus = AD add
  where
    add (a, b) = (a + b, L ladd)

-- | Subtraction of the second number from the first.
--
-- The derivative part is @lmult (-1) 1@, so the two halves of the incoming
-- map are scaled by @1@ and @-1@ respectively.
minus :: AD (Double, Double) Double
minus = AD sub
  where
    sub (a, b) = (a - b, L (lmult (-1) 1))

-- | Multiplication of two numbers; the derivative part is 'lmult' taken at
-- the point of evaluation.
mult :: AD (Double, Double) Double
mult = AD times
  where
    times (a, b) = (a * b, L (lmult a b))

-- | Multiplication by the constant @k@; the derivative part is @linear k@.
scale :: Double -> AD Double Double
scale k = AD scaled
  where
    scaled v = (k * v, linear k)

-- | Run a differentiable function at a point.
--
-- Returns the value at that point together with the derivative converted to
-- a matrix by 'evalL'.
evaluateAD :: (HasDim a, HasDim b) => AD a b -> a -> (b, LA.Matrix Double)
evaluateAD (AD run) point = (value, evalL derivative)
  where
    (value, derivative) = run point

-------------------------------------------------------------------------------
-- Simple examples
-------------------------------------------------------------------------------

-- | @x + x@: duplicate the input and add the two copies.
ex1 :: AD Double Double
ex1 = plus %% duplicate
  where
    duplicate = fanout identity identity

-- | @x * x@: duplicate the input and multiply the two copies.
ex2 :: AD Double Double
ex2 = mult %% duplicate
  where
    duplicate = fanout identity identity

-------------------------------------------------------------------------------
-- Quadratic function
-------------------------------------------------------------------------------

-- | The quadratic @x * x + y * y + 5@, written in arrow notation.
quad :: AD (Double, Double) Double
quad = proc (x, y) -> do
    xSquared <- mult    -< (x, x)
    ySquared <- mult    -< (y, y)
    squares  <- plus    -< (xSquared, ySquared)
    offset   <- konst 5 -< ()
    plus -< (squares, offset)

-------------------------------------------------------------------------------
-- Newton
-------------------------------------------------------------------------------

-- | Damped Newton iteration for a zero of @f@.
--
-- Starting from @x0@, each step moves to @x - 0.1 * (f x / f' x)@. The result
-- is the first ten iterates, the starting point included.
findZero :: AD Double Double -> Double -> [Double]
findZero f x0 = take 10 (iterate step x0)
  where
    step :: Double -> Double
    step x = x - damping * (value / slope)
      where
        (value, jacobian) = evaluateAD f x
        -- For a scalar function the matrix holds a single entry.
        [[slope]] = LA.toLists jacobian

    damping = 0.1

-------------------------------------------------------------------------------
-- Gradient descent
-------------------------------------------------------------------------------

-- | Gradient descent with a fixed step factor of @0.1@.
--
-- Produces the infinite list of iterates, beginning with the starting point.
-- Each step subtracts the scaled gradient from the current point,
-- component-wise on the 'toVector' representation.
gradDesc :: forall a. VectorSpace a => AD a Double -> a -> [a]
gradDesc f = iterate descend
  where
    descend :: a -> a
    descend point = fromVector (zipWith (-) (toVector point) delta)
      where
        jacobian = snd (evaluateAD f point)
        -- NOTE(review): relies on the transposed matrix having exactly one
        -- row; confirm against the orientation produced by 'evalL'.
        [delta] = LA.toLists (LA.tr (LA.scale rate jacobian))

    rate = 0.1

-------------------------------------------------------------------------------
-- Random
-------------------------------------------------------------------------------

-- | Infinite, deterministic stream of pseudo-random 'Double's produced by a
-- SplitMix generator initialised from the given seed.
randomDoubles :: Word64 -> [Double]
randomDoubles seed = stream (SM.mkSMGen seed)
  where
    stream gen = value : stream next
      where
        (value, next) = SM.nextDouble gen

-------------------------------------------------------------------------------
-- Dot
-------------------------------------------------------------------------------

-- | Types built from 'Double's (see the instances in this module) that
-- support differentiable component-wise operations.
class VectorSpace' a where
    -- | Differentiable sum of all components.
    sumN :: AD a Double
    -- | Differentiable component-wise product of two values.
    multN :: AD (a, a) a

-- | Pairs are handled component by component.
instance (VectorSpace' a, VectorSpace' b) => VectorSpace' (a, b) where
    -- Sum each half separately, then add the two partial sums.
    sumN = proc (left, right) -> do
        leftSum  <- sumN -< left
        rightSum <- sumN -< right
        plus -< (leftSum, rightSum)

    -- Multiply matching halves of the two operands and re-pair the results.
    multN = proc ((xa, xb), (ya, yb)) -> do
        prodA <- multN -< (xa, ya)
        prodB <- multN -< (xb, yb)
        identity -< (prodA, prodB)

-- | Base case: a single 'Double' is its own sum, and the component-wise
-- product is ordinary multiplication.
instance VectorSpace' Double where
    sumN  = identity
    multN = mult

-- | Dot product: multiply the two operands component-wise ('multN'), then
-- sum the components of the result ('sumN').
dot :: VectorSpace' a => AD (a, a) Double
dot = sumN %% multN

-------------------------------------------------------------------------------
-- ML stuff
-------------------------------------------------------------------------------

-- | Hyperbolic tangent with derivative @1 - tanh x * tanh x@.
tanhAD :: AD Double Double
tanhAD = AD eval
  where
    eval x = (t, linear (1 - t * t))
      where
        t = tanh x

-- | Logistic sigmoid @1 / (1 + exp (-x))@.
--
-- The value at @x@ is the sigmoid @y@ itself, and the derivative is
-- @y * (1 - y)@.
sigmoidAD :: AD Double Double
sigmoidAD = AD $ \x ->
    let y = 1 / (1 + exp (- x))
    -- Return the computed sigmoid value, not the input.
    in (y, linear (y * (1 - y)))

-- | Weights of a single neuron with two inputs: the pair of input weights
-- followed by the bias.
type Weights' = ((Double, Double), Double)

-- | Weights of the whole network: the two internal neurons, followed by the
-- final output neuron.
type Weights = ((Weights', Weights'), Weights')

-- | Initial network weights, filled from the pseudo-random stream with a
-- fixed seed so that every run starts from the same point.
startWeights :: Weights
startWeights = fromVector (randomDoubles seed)
  where
    seed = 1337

-- | A single neuron: @tanh (dot weights inputs + bias)@.
neuron :: AD (Weights', (Double, Double)) Double
neuron = proc ((weights, offset), inputs) -> do
    weighted <- dot -< (weights, inputs)
    tanhAD %% plus -< (weighted, offset)

-- | The full network: two neurons look at the same input pair, and a third
-- neuron combines their outputs.
network :: AD (Weights, (Double, Double)) Double
network = proc (((hiddenA, hiddenB), final), input) -> do
    outA <- neuron -< (hiddenA, input)
    outB <- neuron -< (hiddenB, input)
    neuron -< (final, (outA, outB))

-- | Training loss as a function of the weights: the sum of squared errors of
-- 'network' over the four rows of the XOR truth table.
networkError :: AD Weights Double
networkError = proc weights -> do
    -- xor!
    err11 <- sample 1 1 0 -< weights
    err00 <- sample 0 0 0 -< weights
    err10 <- sample 1 0 1 -< weights
    err01 <- sample 0 1 1 -< weights

    sumN -< ((err11, err00), (err10, err01))
  where
    -- Squared error of the network on one sample: the two inputs followed by
    -- the expected output.
    sample :: Double -> Double -> Double -> AD Weights Double
    sample inA inB expected = proc weights -> do
         a        <- konst inA      -< ()
         b        <- konst inB      -< ()
         target   <- konst expected -< ()
         actual   <- network        -< (weights, (a, b))
         residual <- minus          -< (target, actual)
         mult -< (residual, residual)

-- | Weights after 500 gradient-descent steps on 'networkError', starting from
-- 'startWeights'.
train :: Weights
train = iterates !! 500
  where
    iterates = gradDesc networkError startWeights

-------------------------------------------------------------------------------
-- Main
-------------------------------------------------------------------------------

-- | Demo driver.
--
-- Prints a few evaluations, trains the XOR network and shows its predictions.
-- When the command-line arguments contain @plot@, the network output on a
-- regular grid over the unit square is also written to @datafile.dat@ as
-- tab-separated rows.
main :: IO ()
main = do
    putStrLn $ "quad (2,3) = " ++ show (evaluateAD quad (2,3))
    putStrLn $ "gradDesc quad (2,3) = " ++ show (gradDesc quad (2,3) !! 30)

    print (evaluateAD tanhAD 1)
    print (evaluateAD sigmoidAD 1)

    putStrLn "Training the net (for xor)"
    let ws = train
        -- Network output for one input pair, using the trained weights.
        predict xy = fst (evaluateAD network (ws, xy))
        example xy =
          putStrLn $ "eval " ++ show xy ++ " = " ++ showFFloat (Just 2) (predict xy) ""
    putStrLn $ "Parameters = " ++ show (toVector ws)
    putStrLn $ "Error = " ++ show (fst $ evaluateAD networkError ws)

    mapM_ example [(0, 0), (0, 1), (1, 0), (1, 1)]

    args <- getArgs
    when ("plot" `elem` args) $ do
        putStrLn "Outputting plot data: datafile.dat"

        let steps = 20 :: Int
            grid  = map (\i -> fromIntegral i / fromIntegral steps) [0 .. steps] :: [Double]
            row x y = intercalate "\t" [show x, show y, show (predict (x, y))]

        writeFile "datafile.dat" (unlines [ row x y | x <- grid, y <- grid ])