{- | This module implements reverse-mode automatic differentiation.

Machine learning and training of models are based on calculating gradients of operations.
This can be done symbolically by dynamically creating a graph of all operations,
which is then traversed to obtain the gradient.

"Synapse" provides several operations that support automatic differentiation,
but you can easily extend that list: you just need to define a function
that returns a 'Symbol' with correct local gradients.
You can check out implementations in the source to give yourself a reference
and read more about it in 'Symbol' datatype docs.
-}


-- 'FlexibleInstances' and 'TypeFamilies' are needed to instantiate 'Indexable', 'ElementwiseScalarOps', 'SingletonOps', 'VecOps', 'MatOps' typeclasses.

{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies      #-}


module Synapse.Autograd
    ( -- * 'Symbolic' and 'Symbol'


      Symbolic (symbolicZero, symbolicOne, symbolicN)

    , SymbolIdentifier (SymbolIdentifier, unSymbolIdentifier)

    , Symbol (Symbol, symbolIdentifier, unSymbol, symbolGradients)
    , SymbolVec
    , SymbolMat

    , symbol
    , constSymbol
    , renameSymbol

    , symbolicUnaryOp
    , symbolicBinaryOp

      -- * 'Gradients' calculation


    , Gradients (unGradients)
    , allGradients
    , getGradientsOf
    , wrt
    , nthPartialGradient
    , nthGradient
    ) where


import Synapse.Tensors (DType, Indexable(..), ElementwiseScalarOps(..), ElementwiseScalarOps(..), SingletonOps(..), VecOps(..), MatOps(..))

import Synapse.Tensors.Vec (Vec)
import qualified Synapse.Tensors.Vec as V

import Synapse.Tensors.Mat (Mat)
import qualified Synapse.Tensors.Mat as M

import Data.Foldable (foldl')
import Data.String (IsString(..))

import Data.Hashable (Hashable(..))

import qualified Data.HashMap.Lazy as HM


{- | 'Symbolic' typeclass describes types with few properties that are needed for autogradient.

Members of this typeclass could have default implementation due to 'Num', but such implementation is not always correct.
'Synapse.Tensors.Vec.Vec's and 'Synapse.Tensors.Mat.Mat's do not have only one zero or identity element, and so numerical literal is not enough.
'symbolicZero' and 'symbolicOne' function additionally take reference value to consider dimensions.
Absence of default implementations forces to manually ensure correctness of those functions.

"Synapse" provides implementations for primitive types ('Int', 'Float', 'Double'),
and for containers types ('Synapse.Tensors.Vec.Vec', 'Synapse.Tensors.Mat.Mat').

Detailed laws of 'Symbolic' properties are in the docs for associated functions.
-}
class (Eq a, Num a) => Symbolic a where
    -- | Returns additive and multiplicative (elementwise) zero element. Argument is passed for the reference of the dimension.
    symbolicZero :: a -> a

    -- | Returns multiplicative (elementwise) identity element. Argument is passed for the reference of the dimension.
    symbolicOne :: a -> a

    -- | Returns what could be considered @N@ constant (sum of @N@ 'symbolicOne's). Argument is passed for the reference of the dimension.
    --
    -- The default implementation only uses the reference through 'symbolicZero' and 'symbolicOne',
    -- so the reference's own value never contributes to the result.
    -- Negative @N@ yields the negated sum, and @0@ yields 'symbolicZero'.
    symbolicN :: Int -> a -> a
    symbolicN n c
        | n < 0     = negate (sumOfOnes (negate n))
        | otherwise = sumOfOnes n
      where
        -- The fold must start from 'symbolicZero' rather than from @c@:
        -- starting from @c@ would compute @c + N@ instead of @N@.
        sumOfOnes k = foldl' (+) (symbolicZero c) (replicate k (symbolicOne c))

-- Instances


-- | Integers have a single zero and one, so the reference argument is ignored.
instance Symbolic Int where
    symbolicZero _ = 0
    symbolicOne _  = 1
    -- The constant @N@ is just @N@ itself.
    symbolicN n _  = n

-- | Scalar floats ignore the reference argument: there is only one zero and one.
instance Symbolic Float where
    symbolicZero _ = 0.0
    symbolicOne _  = 1.0
    -- @N@ is converted directly rather than summing ones.
    symbolicN n _  = fromIntegral n

-- | Scalar doubles ignore the reference argument: there is only one zero and one.
instance Symbolic Double where
    symbolicZero _ = 0.0
    symbolicOne _  = 1.0
    -- @N@ is converted directly rather than summing ones.
    symbolicN n _  = fromIntegral n


-- | Each method builds a vector of the same size as the reference, filled with a single constant.
instance Symbolic a => Symbolic (Vec a) where
    symbolicZero reference = V.replicate len 0
      where len = V.size reference
    symbolicOne reference = V.replicate len 1
      where len = V.size reference
    symbolicN n reference = V.replicate len (fromIntegral n)
      where len = V.size reference

-- | Each method builds a matrix of the same size as the reference, filled with a single constant.
instance Symbolic a => Symbolic (Mat a) where
    symbolicZero reference = M.replicate dims 0
      where dims = M.size reference
    symbolicOne reference = M.replicate dims 1
      where dims = M.size reference
    symbolicN n reference = M.replicate dims (fromIntegral n)
      where dims = M.size reference


-- | 'SymbolIdentifier' is a newtype around the string that uniquely names a symbol.
--
-- Identifiers can be written as string literals ('IsString'), concatenated with '<>',
-- and 'mempty' is the empty identifier.
newtype SymbolIdentifier = SymbolIdentifier
    { unSymbolIdentifier :: String  -- ^ Identifier of a symbol.
    } deriving (Show, Eq)

instance IsString SymbolIdentifier where
    fromString str = SymbolIdentifier str

-- | Concatenates the underlying strings.
instance Semigroup SymbolIdentifier where
    SymbolIdentifier lhs <> SymbolIdentifier rhs = SymbolIdentifier (lhs ++ rhs)

-- | The identity is the empty identifier.
instance Monoid SymbolIdentifier where
    mempty = fromString ""

{- | Datatype that represents a symbol variable: a variable whose operations are recorded so that derivatives can be obtained symbolically.

Any operation returning @Symbol a@ where @a@ is 'Symbolic' can be autogradiented. The returned 'Symbol' carries a 'symbolGradients' list,
which lets "Synapse" build the computation graph and obtain the needed gradients.
Each entry of 'symbolGradients' is a pair:

* the first element is the symbol wrt which a gradient can be taken;
* the second element is a closure implementing the chain rule. It takes the incoming gradient and multiplies it by the local derivative.

You can check out implementations of those operations in the source to give yourself a reference.
-}
data Symbol a = Symbol
    { symbolIdentifier :: SymbolIdentifier
      -- ^ Name of a symbol (identifier for differentiation).
    , unSymbol :: a
      -- ^ Value of a symbol.
    , symbolGradients :: [(Symbol a, Symbol a -> Symbol a)]
      -- ^ Local gradients: the symbol each gradient is taken wrt, paired with the closure that computes it.
    }

-- | Creates a new named symbol with no recorded gradients.
-- The name is what later allows differentiating wrt this symbol.
symbol :: SymbolIdentifier -> a -> Symbol a
symbol name value = Symbol
    { symbolIdentifier = name
    , unSymbol         = value
    , symbolGradients  = []
    }

-- | Creates a constant symbol: it is given the empty identifier ('mempty') and no gradients,
-- so it has no name to differentiate wrt.
constSymbol :: a -> Symbol a
constSymbol value = symbol mempty value

-- | Gives a symbol a new identifier while keeping its value and recorded gradients.
-- Note: renaming practically creates a new symbol for gradient calculation, which can then be differentiated wrt.
renameSymbol :: SymbolIdentifier -> Symbol a -> Symbol a
renameSymbol name sym = sym { symbolIdentifier = name }


-- | @SymbolVec a@ type alias stands for @Symbol (Vec a)@: a symbol whose value is a vector of @a@.

type SymbolVec a = Symbol (Vec a)

-- | @SymbolMat a@ type alias stands for @Symbol (Mat a)@: a symbol whose value is a matrix of @a@.

type SymbolMat a = Symbol (Mat a)


-- Typeclasses


-- | Renders the identifier and the value; recorded gradients are not shown.
instance Show a => Show (Symbol a) where
    show sym = concat ["Symbol ", show (symbolIdentifier sym), ": ", show (unSymbol sym)]


-- | Symbols are compared by identifier only; values and gradients are ignored.
-- As a consequence, any two symbols with the empty identifier compare equal.
instance Eq (Symbol a) where
    lhs == rhs = symbolIdentifier lhs == symbolIdentifier rhs


-- | Hashes only the identifier string, matching the identifier-only 'Eq' instance.
instance Hashable (Symbol a) where
    hashWithSalt salt = hashWithSalt salt . unSymbolIdentifier . symbolIdentifier


-- Symbolic symbols


-- | Special constants of a symbol are constant symbols built from the wrapped value's constants.
-- None of them record gradients, so they never add nodes to the computation graph.
instance Symbolic a => Symbolic (Symbol a) where
    symbolicZero x = constSymbol $ symbolicZero $ unSymbol x
    symbolicOne x = constSymbol $ symbolicOne $ unSymbol x
    -- Delegate to the wrapped type so the result is a plain constant @N@.
    -- Relying on the class default would build @N@ out of symbolic additions, which then
    -- feeds gradient formulas such as the one in 'sqrt'.
    symbolicN n x = constSymbol $ symbolicN n $ unSymbol x


-- The element type of a symbolic vector is the element type @a@ of the wrapped 'Vec'.
type instance DType (SymbolVec a) = a

-- The element type of a symbolic matrix is the element type @a@ of the wrapped 'Mat'.
type instance DType (SymbolMat a) = a


-- | Lifts a unary operation on values to an operation on symbols.
--
-- The result has the empty identifier, holds @op@ applied to the operand's value,
-- and stores the given list of local gradients unchanged.
symbolicUnaryOp :: (a -> a) -> Symbol a -> [(Symbol a, Symbol a -> Symbol a)] -> Symbol a
symbolicUnaryOp op x gradients = Symbol
    { symbolIdentifier = mempty
    , unSymbol         = op (unSymbol x)
    , symbolGradients  = gradients
    }

-- | Lifts a binary operation on values to an operation on symbols.
--
-- The result has the empty identifier, holds @op@ applied to both operands' values,
-- and stores the given list of local gradients unchanged.
symbolicBinaryOp :: (a -> a -> a) -> Symbol a -> Symbol a -> [(Symbol a, Symbol a -> Symbol a)] -> Symbol a
symbolicBinaryOp op a b gradients = Symbol
    { symbolIdentifier = mempty
    , unSymbol         = unSymbol a `op` unSymbol b
    , symbolGradients  = gradients
    }

-- | Symbolic arithmetic. Each operation records, for every operand, a closure that
-- multiplies the incoming gradient by the local derivative wrt that operand.
instance Symbolic a => Num (Symbol a) where
    a + b = symbolicBinaryOp (+) a b [(a, id), (b, id)]
    a - b = symbolicBinaryOp (-) a b [(a, id), (b, negate)]
    negate x = symbolicUnaryOp negate x [(x, negate)]
    a * b = symbolicBinaryOp (*) a b [(a, (* b)), (b, (a *))]
    -- d|x|/dx = signum x. The incoming gradient must be multiplied by it;
    -- it must not be passed through 'signum' itself.
    abs x = symbolicUnaryOp abs x [(x, (* signum x))]
    -- signum is piecewise constant, so its derivative is zero (almost everywhere).
    signum x = symbolicUnaryOp signum x [(x, (* symbolicZero x))]
    fromInteger = constSymbol . fromInteger

-- | Symbolic division. Gradient closures multiply the incoming gradient by the local derivative.
instance (Symbolic a, Fractional a) => Fractional (Symbol a) where
    a / b = symbolicBinaryOp (/) a b [(a, wrtNumerator), (b, wrtDenominator)]
      where
        -- d(a/b)/da = 1/b
        wrtNumerator   = (/ b)
        -- d(a/b)/db = -a/b^2
        wrtDenominator = (* (negate a / (b * b)))
    recip x = symbolicUnaryOp recip x [(x, (* d))]
      where
        -- d(1/x)/dx = -1/x^2
        d = negate (symbolicOne x) / (x * x)
    fromRational = constSymbol . fromRational

-- | Symbolic transcendental functions. For each unary function, @d@ is the local derivative
-- wrt the argument; the recorded closure multiplies the incoming gradient by it.
instance (Symbolic a, Floating a) => Floating (Symbol a) where
    pi = constSymbol pi

    a ** b = symbolicBinaryOp (**) a b [(a, (* dBase)), (b, (* dExponent))]
      where
        -- d(a^b)/da = b * a^(b-1)
        dBase     = b * a ** (b - symbolicOne b)
        -- d(a^b)/db = a^b * ln a
        dExponent = a ** b * log a

    sqrt x = symbolicUnaryOp sqrt x [(x, (* d))]
      where d = recip (symbolicN 2 x * sqrt x)

    exp x = symbolicUnaryOp exp x [(x, (* d))]
      where d = exp x

    log x = symbolicUnaryOp log x [(x, (* d))]
      where d = recip x

    sin x = symbolicUnaryOp sin x [(x, (* d))]
      where d = cos x

    cos x = symbolicUnaryOp cos x [(x, (* d))]
      where d = negate (sin x)

    asin x = symbolicUnaryOp asin x [(x, (* d))]
      where d = recip (sqrt (symbolicOne x - x * x))

    acos x = symbolicUnaryOp acos x [(x, (* d))]
      where d = negate (recip (sqrt (symbolicOne x - x * x)))

    atan x = symbolicUnaryOp atan x [(x, (* d))]
      where d = recip (symbolicOne x + x * x)

    sinh x = symbolicUnaryOp sinh x [(x, (* d))]
      where d = cosh x

    cosh x = symbolicUnaryOp cosh x [(x, (* d))]
      where d = sinh x

    asinh x = symbolicUnaryOp asinh x [(x, (* d))]
      where d = recip (sqrt (symbolicOne x + x * x))

    acosh x = symbolicUnaryOp acosh x [(x, (* d))]
      where d = recip (sqrt (x * x - symbolicOne x))

    atanh x = symbolicUnaryOp atanh x [(x, (* d))]
      where d = recip (symbolicOne x - x * x)

instance Symbolic a => ElementwiseScalarOps (Symbol (Vec a)) where
    +. :: Num (DType (Symbol (Vec a))) =>
Symbol (Vec a) -> DType (Symbol (Vec a)) -> Symbol (Vec a)
(+.) Symbol (Vec a)
x DType (Symbol (Vec a))
n = Symbol (Vec a)
x Symbol (Vec a) -> Symbol (Vec a) -> Symbol (Vec a)
forall a. Num a => a -> a -> a
+ Vec a -> Symbol (Vec a)
forall a. a -> Symbol a
constSymbol (Int -> a -> Vec a
forall a. Int -> a -> Vec a
V.replicate (Vec a -> Int
forall a. Vec a -> Int
V.size (Vec a -> Int) -> Vec a -> Int
forall a b. (a -> b) -> a -> b
$ Symbol (Vec a) -> Vec a
forall a. Symbol a -> a
unSymbol Symbol (Vec a)
x) a
DType (Symbol (Vec a))
n)
    -. :: Num (DType (Symbol (Vec a))) =>
Symbol (Vec a) -> DType (Symbol (Vec a)) -> Symbol (Vec a)
(-.) Symbol (Vec a)
x DType (Symbol (Vec a))
n = Symbol (Vec a)
x Symbol (Vec a) -> Symbol (Vec a) -> Symbol (Vec a)
forall a. Num a => a -> a -> a
- Vec a -> Symbol (Vec a)
forall a. a -> Symbol a
constSymbol (Int -> a -> Vec a
forall a. Int -> a -> Vec a
V.replicate (Vec a -> Int
forall a. Vec a -> Int
V.size (Vec a -> Int) -> Vec a -> Int
forall a b. (a -> b) -> a -> b
$ Symbol (Vec a) -> Vec a
forall a. Symbol a -> a
unSymbol Symbol (Vec a)
x) a
DType (Symbol (Vec a))
n)
    *. :: Num (DType (Symbol (Vec a))) =>
Symbol (Vec a) -> DType (Symbol (Vec a)) -> Symbol (Vec a)
(*.) Symbol (Vec a)
x DType (Symbol (Vec a))
n = Symbol (Vec a)
x Symbol (Vec a) -> Symbol (Vec a) -> Symbol (Vec a)
forall a. Num a => a -> a -> a
* Vec a -> Symbol (Vec a)
forall a. a -> Symbol a
constSymbol (Int -> a -> Vec a
forall a. Int -> a -> Vec a
V.replicate (Vec a -> Int
forall a. Vec a -> Int
V.size (Vec a -> Int) -> Vec a -> Int
forall a b. (a -> b) -> a -> b
$ Symbol (Vec a) -> Vec a
forall a. Symbol a -> a
unSymbol Symbol (Vec a)
x) a
DType (Symbol (Vec a))
n)
    /. :: Fractional (DType (Symbol (Vec a))) =>
Symbol (Vec a) -> DType (Symbol (Vec a)) -> Symbol (Vec a)
(/.) Symbol (Vec a)
x DType (Symbol (Vec a))
n = Symbol (Vec a)
x Symbol (Vec a) -> Symbol (Vec a) -> Symbol (Vec a)
forall a. Fractional a => a -> a -> a
/ Vec a -> Symbol (Vec a)
forall a. a -> Symbol a
constSymbol (Int -> a -> Vec a
forall a. Int -> a -> Vec a
V.replicate (Vec a -> Int
forall a. Vec a -> Int
V.size (Vec a -> Int) -> Vec a -> Int
forall a b. (a -> b) -> a -> b
$ Symbol (Vec a) -> Vec a
forall a. Symbol a -> a
unSymbol Symbol (Vec a)
x) a
DType (Symbol (Vec a))
n)
    **. :: Floating (DType (Symbol (Vec a))) =>
Symbol (Vec a) -> DType (Symbol (Vec a)) -> Symbol (Vec a)
(**.) Symbol (Vec a)
x DType (Symbol (Vec a))
n = Symbol (Vec a)
x Symbol (Vec a) -> Symbol (Vec a) -> Symbol (Vec a)
forall a. Floating a => a -> a -> a
** Vec a -> Symbol (Vec a)
forall a. a -> Symbol a
constSymbol (Int -> a -> Vec a
forall a. Int -> a -> Vec a
V.replicate (Vec a -> Int
forall a. Vec a -> Int
V.size (Vec a -> Int) -> Vec a -> Int
forall a b. (a -> b) -> a -> b
$ Symbol (Vec a) -> Vec a
forall a. Symbol a -> a
unSymbol Symbol (Vec a)
x) a
DType (Symbol (Vec a))
n)

    -- Elementwise minimum with scalar @n@. The local gradient passes the
    -- incoming gradient through where the element of @x@ was kept
    -- (element <= n) and zeroes it elsewhere.
    elementsMin x n = symbolicUnaryOp (\v -> elementsMin v n) x [(x, \path -> path * selected)]
      where
        raw      = unSymbol x
        selected = constSymbol (V.generate (V.size raw) (\i -> if unsafeIndex raw i <= n then 1 else 0))
    -- Elementwise maximum with scalar @n@. The local gradient passes the
    -- incoming gradient through where the element of @x@ was kept
    -- (element >= n) and zeroes it elsewhere.
    elementsMax x n = symbolicUnaryOp (\v -> elementsMax v n) x [(x, \path -> path * selected)]
      where
        raw      = unSymbol x
        selected = constSymbol (V.generate (V.size raw) (\i -> if unsafeIndex raw i >= n then 1 else 0))

instance Symbolic a => ElementwiseScalarOps (SymbolMat a) where
    -- Arithmetic with a scalar: the scalar is broadcast into a constant matrix
    -- of the same size as @x@, and the corresponding symbolic 'Num' /
    -- 'Fractional' / 'Floating' operation provides the gradient.
    x +. n = x + filled
      where filled = constSymbol (M.replicate (M.size (unSymbol x)) n)
    x -. n = x - filled
      where filled = constSymbol (M.replicate (M.size (unSymbol x)) n)
    x *. n = x * filled
      where filled = constSymbol (M.replicate (M.size (unSymbol x)) n)
    x /. n = x / filled
      where filled = constSymbol (M.replicate (M.size (unSymbol x)) n)
    x **. n = x ** filled
      where filled = constSymbol (M.replicate (M.size (unSymbol x)) n)

    -- Elementwise minimum with scalar @n@: the gradient passes through where the
    -- element of @x@ was kept (element <= n) and is zero elsewhere.
    elementsMin x n = symbolicUnaryOp (\m -> elementsMin m n) x [(x, (* mask))]
      where
        raw  = unSymbol x
        mask = constSymbol (M.generate (M.size raw) (\ix -> if unsafeIndex raw ix <= n then 1 else 0))
    -- Elementwise maximum with scalar @n@: the gradient passes through where the
    -- element of @x@ was kept (element >= n) and is zero elsewhere.
    elementsMax x n = symbolicUnaryOp (\m -> elementsMax m n) x [(x, (* mask))]
      where
        raw  = unSymbol x
        mask = constSymbol (M.generate (M.size raw) (\ix -> if unsafeIndex raw ix >= n then 1 else 0))

instance Symbolic a => SingletonOps (SymbolVec a) where
    -- Singleton handling is delegated to the underlying 'Vec'; results that are
    -- symbols are wrapped with 'constSymbol'.
    singleton value = constSymbol (singleton value)
    isSingleton s = isSingleton (unSymbol s)
    unSingleton s = unSingleton (unSymbol s)

    -- Note: the result is a constant symbol, so it carries no gradient links
    -- back to @vec@ or @reference@.
    extendSingleton vec reference =
        constSymbol (extendSingleton (unSymbol vec) (unSymbol reference))

    -- d(sum x)/dx_i = 1: the singleton incoming gradient is broadcast to x's size.
    elementsSum x = symbolicUnaryOp elementsSum x [(x, \path -> extendSingleton path x)]
    -- d(prod x)/dx_i is computed as (prod x) / x_i.
    -- NOTE(review): this divides by each element, so for floating types an
    -- element equal to zero yields a non-finite gradient at that position.
    elementsProduct x = symbolicUnaryOp elementsProduct x [(x, backward)]
      where
        raw      = unSymbol x
        total    = unSingleton (elementsProduct raw)
        partials = V.generate (V.size raw) (\i -> total / unsafeIndex raw i)
        backward path = extendSingleton path x * constSymbol partials

    -- d(mean x)/dx_i = 1 / size x.
    mean x = symbolicUnaryOp mean x [(x, \path -> extendSingleton path x /. count)]
      where
        count = fromIntegral (V.size (unSymbol x))

    -- Gradient of the norm is x / norm x; the norm value is taken as a constant.
    norm x = symbolicUnaryOp norm x [(x, \path -> extendSingleton path x * (x /. magnitude))]
      where
        magnitude = unSingleton (norm (unSymbol x))

instance Symbolic a => SingletonOps (SymbolMat a) where
    -- Singleton handling is delegated to the underlying 'Mat'; results that are
    -- symbols are wrapped with 'constSymbol'.
    singleton value = constSymbol (singleton value)
    isSingleton s = isSingleton (unSymbol s)
    unSingleton s = unSingleton (unSymbol s)

    -- Note: the result is a constant symbol, so it carries no gradient links
    -- back to @mat@ or @reference@.
    extendSingleton mat reference =
        constSymbol (extendSingleton (unSymbol mat) (unSymbol reference))

    -- d(sum x)/dx_ij = 1: the singleton incoming gradient is broadcast to x's size.
    elementsSum x = symbolicUnaryOp elementsSum x [(x, \path -> extendSingleton path x)]
    -- d(prod x)/dx_ij is computed as (prod x) / x_ij.
    -- NOTE(review): this divides by each element, so for floating types an
    -- element equal to zero yields a non-finite gradient at that position.
    elementsProduct x = symbolicUnaryOp elementsProduct x [(x, backward)]
      where
        raw      = unSymbol x
        total    = unSingleton (elementsProduct raw)
        partials = M.generate (M.size raw) (\ix -> total / unsafeIndex raw ix)
        backward path = extendSingleton path x * constSymbol partials

    -- d(mean x)/dx_ij = 1 / (number of elements of x).
    mean x = symbolicUnaryOp mean x [(x, \path -> extendSingleton path x /. count)]
      where
        count = fromIntegral (M.nElements (unSymbol x))

    -- Gradient of the norm is x / norm x; the norm value is taken as a constant.
    norm x = symbolicUnaryOp norm x [(x, \path -> extendSingleton path x * (x /. magnitude))]
      where
        magnitude = unSingleton (norm (unSymbol x))

instance Symbolic a => VecOps (SymbolVec a) where
    -- Dot product built from existing symbolic operations (elementwise product,
    -- then sum), so its gradient comes from those operations' gradients.
    dot u v = elementsSum (u * v)

instance Symbolic a => MatOps (SymbolMat a) where
    -- y = transpose x, so the incoming gradient only has to be transposed back:
    -- dL/dx = transpose (dL/dy). Using the symbolic 'transpose' keeps the result
    -- differentiable for higher-order gradients.
    transpose x = symbolicUnaryOp M.transpose x [(x, transpose)]

    -- 'addMatRow' adds @row@ to every row of @mat@.
    -- The gradient with respect to @mat@ is the incoming gradient itself.
    -- The gradient with respect to @row@ is the incoming gradient summed over
    -- all rows. It is computed symbolically as @ones(1 x rows) `matMul` path@,
    -- which yields a 1 x cols matrix.
    -- NOTE(review): assumes 'M.size' returns (rows, cols) - confirm against Synapse.Tensors.Mat.
    addMatRow mat row = symbolicBinaryOp addMatRow mat row
        [ (mat, id)
        , (row, \path -> constSymbol (M.replicate (1, fst (M.size (unSymbol path))) 1) `matMul` path)
        ]

    -- C = A B  =>  dL/dA = path * transpose B,  dL/dB = transpose A * path.
    matMul a b = symbolicBinaryOp M.matMul a b
        [ (a, (`matMul` transpose b))
        , (b, (transpose a `matMul`))
        ]


-- Gradients calculation


-- | 'Gradients' datatype holds all gradients of one symbol with respect to other symbols.
--
-- Keys are the symbols that were differentiated against; each value is the
-- gradient with respect to its key (see 'wrt').
newtype Gradients a = Gradients
    { unGradients :: HM.HashMap (Symbol a) (Symbol a)  -- ^ Map from a symbol to the gradient with respect to it.
    }

-- | Returns key-value pairs of all gradients of symbol.
--
-- Each pair is @(symbol, gradient with respect to symbol)@; the order is the
-- (unspecified) iteration order of the underlying 'HM.HashMap'.
allGradients :: Gradients a -> [(Symbol a, Symbol a)]
allGradients (Gradients grads) = HM.toList grads


-- Typeclasses


instance Show a => Show (Gradients a) where
    -- Rendered as the list of (symbol, gradient) pairs from 'allGradients'.
    show = show . allGradients


-- | Generates 'Gradients' for given symbol.
--
-- The symbol's gradient with respect to itself is 'symbolicOne'. Every symbol
-- reachable through 'symbolGradients' receives the sum, over each path leading
-- to it, of the local gradient functions applied along that path. Entries keyed
-- by a symbol with the 'mempty' identifier are removed before the final
-- self-gradient is inserted.
--
-- NOTE(review): shared sub-graphs are re-traversed once per path, so the cost
-- can grow with the number of paths rather than the number of nodes.
getGradientsOf :: Symbolic a => Symbol a -> Gradients a
getGradientsOf differentiatedSymbol =
    Gradients
        . HM.insert differentiatedSymbol seed
        . HM.delete anonymous
        $ propagate HM.empty differentiatedSymbol seed
  where
    -- Gradient of the differentiated symbol with respect to itself.
    seed = symbolicOne differentiatedSymbol

    -- Lookup key for symbols with the 'mempty' identifier.
    -- NOTE(review): relies on 'Symbol' equality/hashing not forcing the value
    -- field, since it is 'undefined' here.
    anonymous = Symbol mempty undefined []

    -- Push @upstream@ (the gradient reaching @node@) into each child, adding it
    -- to whatever the child has already accumulated, then recurse into the child.
    propagate acc node upstream = foldr visit acc (symbolGradients node)
      where
        visit (child, chain) current =
            let contribution = chain upstream
                updated      = HM.insertWith (\new old -> old + new) child contribution current
            in propagate updated child contribution

-- | Chooses gradient with respect to given symbol.
--
-- A symbol that does not appear in the gradients map gets 'symbolicZero' of itself.
wrt :: Symbolic a => Gradients a -> Symbol a -> Symbol a
wrt gradients x = case HM.lookup x (unGradients gradients) of
    Just gradient -> gradient
    Nothing       -> symbolicZero x

-- | Takes partial gradients with respect to all symbols in a list sequentially, returning the last result.
--
-- With an empty list the original symbol is returned unchanged.
nthPartialGradient :: Symbolic a => Symbol a -> [Symbol a] -> Symbol a
nthPartialGradient start variables = foldl' differentiate start variables
  where
    -- One differentiation step: gradient of the current result with respect to the next symbol.
    differentiate current variable = getGradientsOf current `wrt` variable

-- | Takes the nth order gradient of one symbol with respect to another symbol.
--
-- Order 0 returns the first symbol itself. A negative order raises an 'error'.
nthGradient :: Symbolic a => Int -> Symbol a -> Symbol a -> Symbol a
nthGradient n y x
    | n >= 0    = nthPartialGradient y (replicate n x)
    | otherwise = error "Cannot take negative order gradient"