{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
module Synapse.Autograd
(
Symbolic (symbolicZero, symbolicOne, symbolicN)
, SymbolIdentifier (SymbolIdentifier, unSymbolIdentifier)
, Symbol (Symbol, symbolIdentifier, unSymbol, symbolGradients)
, SymbolVec
, SymbolMat
, symbol
, constSymbol
, renameSymbol
, symbolicUnaryOp
, symbolicBinaryOp
, Gradients (unGradients)
, allGradients
, getGradientsOf
, wrt
, nthPartialGradient
, nthGradient
) where
import Synapse.Tensors (DType, Indexable(..), ElementwiseScalarOps(..), ElementwiseScalarOps(..), SingletonOps(..), VecOps(..), MatOps(..))
import Synapse.Tensors.Vec (Vec)
import qualified Synapse.Tensors.Vec as V
import Synapse.Tensors.Mat (Mat)
import qualified Synapse.Tensors.Mat as M
import Data.Foldable (foldl')
import Data.String (IsString(..))
import Data.Hashable (Hashable(..))
import qualified Data.HashMap.Lazy as HM
-- | Types that can produce the constants automatic differentiation needs.
--
-- Every method receives a /reference/ value and returns a constant shaped
-- like it: scalar instances ignore the reference, container instances
-- (such as 'Vec' and 'Mat') use its size.
class (Eq a, Num a) => Symbolic a where
    -- | Additive identity (zero) shaped like the reference value.
    symbolicZero :: a -> a

    -- | Multiplicative identity (one) shaped like the reference value.
    symbolicOne :: a -> a

    -- | The integer @n@ as a constant shaped like the reference value.
    --
    -- The default starts from 'symbolicZero' and adds 'symbolicOne' @|n|@
    -- times (negating the result for negative @n@), so it costs O(|n|)
    -- additions; instances with a direct conversion should override it.
    symbolicN :: Int -> a -> a
    symbolicN n reference
        | n < 0     = negate (go (abs n) zero)
        | otherwise = go n zero
      where
        -- Accumulate from zero, not from the reference: the reference only
        -- supplies the shape, its value must not leak into the result.
        zero = symbolicZero reference
        one = symbolicOne reference
        go 0 acc = acc
        go i acc = go (i - 1) (acc + one)
-- | 'Int' scalars carry no shape, so the reference value is ignored.
instance Symbolic Int where
    symbolicZero _ = 0
    symbolicOne _ = 1
    symbolicN n _ = n
-- | 'Float' scalars carry no shape, so the reference value is ignored.
instance Symbolic Float where
    symbolicZero _ = 0.0
    symbolicOne _ = 1.0
    symbolicN n _ = fromIntegral n
-- | 'Double' scalars carry no shape, so the reference value is ignored.
instance Symbolic Double where
    symbolicZero _ = 0.0
    symbolicOne _ = 1.0
    symbolicN n _ = fromIntegral n
-- | Vector constants are filled vectors with the reference vector's length.
instance Symbolic a => Symbolic (Vec a) where
    -- Zero and one are just the integers 0 and 1 broadcast to the reference
    -- length (@fromIntegral (0 :: Int)@ is the same value as the literal @0@).
    symbolicZero = symbolicN 0
    symbolicOne = symbolicN 1
    symbolicN n reference = V.replicate (V.size reference) (fromIntegral n)
-- | Matrix constants are filled matrices with the reference matrix's size.
instance Symbolic a => Symbolic (Mat a) where
    -- Zero and one are just the integers 0 and 1 broadcast to the reference
    -- size (@fromIntegral (0 :: Int)@ is the same value as the literal @0@).
    symbolicZero = symbolicN 0
    symbolicOne = symbolicN 1
    symbolicN n reference = M.replicate (M.size reference) (fromIntegral n)
-- | Name of a 'Symbol'. The empty name ('mempty') is used for anonymous
-- symbols such as constants and intermediate results.
newtype SymbolIdentifier = SymbolIdentifier { unSymbolIdentifier :: String }
    deriving (Show, Eq)
-- | Lets string literals stand for identifiers (with @OverloadedStrings@).
instance IsString SymbolIdentifier where
    fromString name = SymbolIdentifier name
-- | Identifiers combine by concatenating their names.
instance Semigroup SymbolIdentifier where
    SymbolIdentifier left <> SymbolIdentifier right = SymbolIdentifier (left ++ right)
-- | The empty name is the identity for concatenation.
instance Monoid SymbolIdentifier where
    mempty = SymbolIdentifier mempty
-- | A node of the automatic-differentiation graph: a named value together
-- with links to the symbols it was computed from.
data Symbol a = Symbol
    { symbolIdentifier :: SymbolIdentifier
      -- ^ Name of the symbol; 'mempty' for anonymous symbols.
    , unSymbol :: a
      -- ^ Value held by the symbol.
    , symbolGradients :: [(Symbol a, Symbol a -> Symbol a)]
      -- ^ Inputs of the operation that produced this symbol, each paired with
      -- a function turning an incoming gradient into that input's
      -- contribution (e.g. @(* cos x)@ for @sin x@). Empty for leaves.
    }
-- | Creates a named leaf symbol: it has no gradient links.
symbol :: SymbolIdentifier -> a -> Symbol a
symbol name value = Symbol
    { symbolIdentifier = name
    , unSymbol = value
    , symbolGradients = []
    }
-- | Wraps a value as an anonymous leaf symbol (empty identifier, no
-- gradient links).
constSymbol :: a -> Symbol a
constSymbol value = symbol mempty value
-- | Replaces a symbol's identifier, keeping its value and gradient links.
renameSymbol :: SymbolIdentifier -> Symbol a -> Symbol a
renameSymbol name s = s { symbolIdentifier = name }
-- | Symbol whose value is a vector.
type SymbolVec a = Symbol (Vec a)

-- | Symbol whose value is a matrix.
type SymbolMat a = Symbol (Mat a)
-- | Renders the identifier and the value; gradient links are not shown.
instance Show a => Show (Symbol a) where
    show (Symbol name value _) = concat ["Symbol ", show name, ": ", show value]
-- | Symbols are equal exactly when their identifiers are equal; values and
-- gradient links are ignored, so all anonymous symbols compare equal.
instance Eq (Symbol a) where
    x == y = symbolIdentifier x == symbolIdentifier y
-- | Hashes only the identifier's name, matching the 'Eq' instance.
instance Hashable (Symbol a) where
    hashWithSalt salt = hashWithSalt salt . unSymbolIdentifier . symbolIdentifier
-- | Symbol constants are anonymous 'constSymbol's shaped like the reference
-- symbol's value.
instance Symbolic a => Symbolic (Symbol a) where
    symbolicZero x = constSymbol $ symbolicZero $ unSymbol x
    symbolicOne x = constSymbol $ symbolicOne $ unSymbol x
    -- Delegate to the wrapped value's instance (like the two methods above)
    -- so the result is a single constant instead of the chain of graph
    -- additions the class default would build.
    symbolicN n x = constSymbol $ symbolicN n $ unSymbol x
-- | The element type of a vector or matrix symbol is @a@.
type instance DType (Symbol (Vec a)) = a
type instance DType (Symbol (Mat a)) = a
-- | Builds an anonymous symbol whose value is @op@ applied to the value of
-- @x@. The list argument becomes the result's 'symbolGradients'.
symbolicUnaryOp :: (a -> a) -> Symbol a -> [(Symbol a, Symbol a -> Symbol a)] -> Symbol a
symbolicUnaryOp op x gradients = Symbol
    { symbolIdentifier = mempty
    , unSymbol = op (unSymbol x)
    , symbolGradients = gradients
    }
-- | Builds an anonymous symbol whose value is @op@ applied to the values of
-- @a@ and @b@. The list argument becomes the result's 'symbolGradients'.
symbolicBinaryOp :: (a -> a -> a) -> Symbol a -> Symbol a -> [(Symbol a, Symbol a -> Symbol a)] -> Symbol a
symbolicBinaryOp op a b gradients = Symbol
    { symbolIdentifier = mempty
    , unSymbol = op (unSymbol a) (unSymbol b)
    , symbolGradients = gradients
    }
-- | Arithmetic on symbols. Each operation computes its value and records,
-- for every input, a function that turns an incoming gradient into that
-- input's contribution (the chain rule).
instance Symbolic a => Num (Symbol a) where
    a + b = symbolicBinaryOp (+) a b [(a, id), (b, id)]
    a - b = symbolicBinaryOp (-) a b [(a, id), (b, negate)]
    negate x = symbolicUnaryOp negate x [(x, negate)]
    a * b = symbolicBinaryOp (*) a b [(a, (* b)), (b, (a *))]
    -- d|x|/dx = signum x, so the incoming gradient is scaled by signum x
    -- (applying signum to the gradient itself would be wrong).
    abs x = symbolicUnaryOp abs x [(x, (* signum x))]
    -- signum is piecewise constant: its derivative is zero, so no gradient
    -- flows back to x.
    signum x = symbolicUnaryOp signum x [(x, const (symbolicZero x))]
    fromInteger = constSymbol . fromInteger
-- | Division on symbols, recording the quotient-rule derivatives.
instance (Symbolic a, Fractional a) => Fractional (Symbol a) where
    a / b = symbolicBinaryOp (/) a b [(a, (/ b)), (b, (* dDenominator))]
      where
        -- d(a / b)/db = -a / b^2
        dDenominator = negate a / (b * b)
    recip x = symbolicUnaryOp recip x [(x, (* dx))]
      where
        -- d(1 / x)/dx = -1 / x^2
        dx = negate (symbolicOne x) / (x * x)
    fromRational r = constSymbol (fromRational r)
-- | Elementary functions on symbols. Each method records its local
-- derivative @dx@; the incoming gradient is multiplied by it.
instance (Symbolic a, Floating a) => Floating (Symbol a) where
    pi = constSymbol pi

    base ** power = symbolicBinaryOp (**) base power
        [(base, (* dBase)), (power, (* dPower))]
      where
        -- d(b ** p)/db = p * b ** (p - 1)
        dBase = power * base ** (power - symbolicOne power)
        -- d(b ** p)/dp = b ** p * log b
        dPower = base ** power * log base

    sqrt x = symbolicUnaryOp sqrt x [(x, (* dx))]
      where
        -- 1 / (2 * sqrt x); symbolicN builds the 2 shaped like x.
        dx = recip (symbolicN 2 x * sqrt x)

    exp x = symbolicUnaryOp exp x [(x, (* dx))]
      where
        dx = exp x

    log x = symbolicUnaryOp log x [(x, (* dx))]
      where
        dx = recip x

    sin x = symbolicUnaryOp sin x [(x, (* dx))]
      where
        dx = cos x

    cos x = symbolicUnaryOp cos x [(x, (* dx))]
      where
        dx = negate (sin x)

    asin x = symbolicUnaryOp asin x [(x, (* dx))]
      where
        dx = recip (sqrt (symbolicOne x - x * x))

    acos x = symbolicUnaryOp acos x [(x, (* dx))]
      where
        dx = negate (recip (sqrt (symbolicOne x - x * x)))

    atan x = symbolicUnaryOp atan x [(x, (* dx))]
      where
        dx = recip (symbolicOne x + x * x)

    sinh x = symbolicUnaryOp sinh x [(x, (* dx))]
      where
        dx = cosh x

    cosh x = symbolicUnaryOp cosh x [(x, (* dx))]
      where
        dx = sinh x

    asinh x = symbolicUnaryOp asinh x [(x, (* dx))]
      where
        dx = recip (sqrt (symbolicOne x + x * x))

    acosh x = symbolicUnaryOp acosh x [(x, (* dx))]
      where
        dx = recip (sqrt (x * x - symbolicOne x))

    atanh x = symbolicUnaryOp atanh x [(x, (* dx))]
      where
        dx = recip (symbolicOne x - x * x)
-- | Scalar operations on vector symbols.
--
-- The arithmetic operators broadcast the scalar @n@ into a constant vector
-- of the same length and reuse the differentiable operators on 'Symbol'.
-- 'elementsMin' and 'elementsMax' record their own gradient: a 0/1 mask
-- from comparing each element of @x@ with @n@ (@<=@ for min, @>=@ for max).
instance Symbolic a => ElementwiseScalarOps (Symbol (Vec a)) where
    x +. n = x + filled
      where filled = constSymbol $ V.replicate (V.size $ unSymbol x) n

    x -. n = x - filled
      where filled = constSymbol $ V.replicate (V.size $ unSymbol x) n

    x *. n = x * filled
      where filled = constSymbol $ V.replicate (V.size $ unSymbol x) n

    x /. n = x / filled
      where filled = constSymbol $ V.replicate (V.size $ unSymbol x) n

    x **. n = x ** filled
      where filled = constSymbol $ V.replicate (V.size $ unSymbol x) n

    elementsMin x n = symbolicUnaryOp (`elementsMin` n) x [(x, (* mask))]
      where
        values = unSymbol x
        -- 1 where the element is <= n, 0 elsewhere.
        mask = constSymbol $ V.generate (V.size values) $ \i ->
            if unsafeIndex values i <= n then 1 else 0

    elementsMax x n = symbolicUnaryOp (`elementsMax` n) x [(x, (* mask))]
      where
        values = unSymbol x
        -- 1 where the element is >= n, 0 elsewhere.
        mask = constSymbol $ V.generate (V.size values) $ \i ->
            if unsafeIndex values i >= n then 1 else 0
instance Symbolic a => ElementwiseScalarOps (SymbolMat a) where
+. :: Num (DType (SymbolMat a)) =>
SymbolMat a -> DType (SymbolMat a) -> SymbolMat a
(+.) SymbolMat a
x DType (SymbolMat a)
n = SymbolMat a
x SymbolMat a -> SymbolMat a -> SymbolMat a
forall a. Num a => a -> a -> a
+ Mat a -> SymbolMat a
forall a. a -> Symbol a
constSymbol ((Int, Int) -> a -> Mat a
forall a. (Int, Int) -> a -> Mat a
M.replicate (Mat a -> (Int, Int)
forall a. Mat a -> (Int, Int)
M.size (Mat a -> (Int, Int)) -> Mat a -> (Int, Int)
forall a b. (a -> b) -> a -> b
$ SymbolMat a -> Mat a
forall a. Symbol a -> a
unSymbol SymbolMat a
x) a
DType (SymbolMat a)
n)
-. :: Num (DType (SymbolMat a)) =>
SymbolMat a -> DType (SymbolMat a) -> SymbolMat a
(-.) SymbolMat a
x DType (SymbolMat a)
n = SymbolMat a
x SymbolMat a -> SymbolMat a -> SymbolMat a
forall a. Num a => a -> a -> a
- Mat a -> SymbolMat a
forall a. a -> Symbol a
constSymbol ((Int, Int) -> a -> Mat a
forall a. (Int, Int) -> a -> Mat a
M.replicate (Mat a -> (Int, Int)
forall a. Mat a -> (Int, Int)
M.size (Mat a -> (Int, Int)) -> Mat a -> (Int, Int)
forall a b. (a -> b) -> a -> b
$ SymbolMat a -> Mat a
forall a. Symbol a -> a
unSymbol SymbolMat a
x) a
DType (SymbolMat a)
n)
*. :: Num (DType (SymbolMat a)) =>
SymbolMat a -> DType (SymbolMat a) -> SymbolMat a
(*.) SymbolMat a
x DType (SymbolMat a)
n = SymbolMat a
x SymbolMat a -> SymbolMat a -> SymbolMat a
forall a. Num a => a -> a -> a
* Mat a -> SymbolMat a
forall a. a -> Symbol a
constSymbol ((Int, Int) -> a -> Mat a
forall a. (Int, Int) -> a -> Mat a
M.replicate (Mat a -> (Int, Int)
forall a. Mat a -> (Int, Int)
M.size (Mat a -> (Int, Int)) -> Mat a -> (Int, Int)
forall a b. (a -> b) -> a -> b
$ SymbolMat a -> Mat a
forall a. Symbol a -> a
unSymbol SymbolMat a
x) a
DType (SymbolMat a)
n)
/. :: Fractional (DType (SymbolMat a)) =>
SymbolMat a -> DType (SymbolMat a) -> SymbolMat a
(/.) SymbolMat a
x DType (SymbolMat a)
n = SymbolMat a
x SymbolMat a -> SymbolMat a -> SymbolMat a
forall a. Fractional a => a -> a -> a
/ Mat a -> SymbolMat a
forall a. a -> Symbol a
constSymbol ((Int, Int) -> a -> Mat a
forall a. (Int, Int) -> a -> Mat a
M.replicate (Mat a -> (Int, Int)
forall a. Mat a -> (Int, Int)
M.size (Mat a -> (Int, Int)) -> Mat a -> (Int, Int)
forall a b. (a -> b) -> a -> b
$ SymbolMat a -> Mat a
forall a. Symbol a -> a
unSymbol SymbolMat a
x) a
DType (SymbolMat a)
n)
-- | Raise every element of a symbolic matrix to a scalar power.
-- The exponent is broadcast into a constant matrix shaped like @x@ and
-- combined with @x@ through the symbolic '**'.
x **. n = x ** broadcast
  where
    broadcast = constSymbol . M.replicate (M.size (unSymbol x)) $ n
-- | Elementwise @min x n@ on a symbolic matrix.
-- The backward pass multiplies the upstream gradient by a 0/1 mask that
-- is 1 where the element of @x@ is @<= n@. On ties the gradient therefore
-- flows to @x@.
elementsMin x n = symbolicUnaryOp (`elementsMin` n) x [(x, (* constSymbol mask))]
  where
    inner = unSymbol x
    mask = M.generate (M.size inner) $ \ix ->
        if unsafeIndex inner ix <= n then 1 else 0
-- | Elementwise @max x n@ on a symbolic matrix.
-- The backward pass multiplies the upstream gradient by a 0/1 mask that
-- is 1 where the element of @x@ is @>= n@. On ties the gradient therefore
-- flows to @x@.
elementsMax x n = symbolicUnaryOp (`elementsMax` n) x [(x, (* constSymbol mask))]
  where
    inner = unSymbol x
    mask = M.generate (M.size inner) $ \ix ->
        if unsafeIndex inner ix >= n then 1 else 0
-- | Singleton construction and reductions for symbolic vectors.
--
-- Each reduction is recorded with 'symbolicUnaryOp'. Its gradient function
-- receives the upstream gradient (a singleton) and broadcasts it back to the
-- shape of the input with 'extendSingleton'.
instance Symbolic a => SingletonOps (SymbolVec a) where
    singleton = constSymbol . singleton

    isSingleton = isSingleton . unSymbol

    unSingleton = unSingleton . unSymbol

    -- The result is wrapped with 'constSymbol', so the graph links of @vec@
    -- are not carried over.
    extendSingleton vec reference =
        constSymbol $ extendSingleton (unSymbol vec) (unSymbol reference)

    -- d(sum)/dx_i = 1, so the upstream gradient is simply broadcast.
    elementsSum x = symbolicUnaryOp elementsSum x [(x, (`extendSingleton` x))]

    -- d(product)/dx_i is the product of every element except x_i.
    -- Dividing the total by x_i is cheap but produces 0/0 (NaN) when x_i is
    -- zero. In that case the product of the remaining elements is computed
    -- explicitly by replacing x_i with 1.
    elementsProduct x = symbolicUnaryOp elementsProduct x [(x, backward)]
      where
        values = unSymbol x
        len = V.size values
        total = unSingleton (elementsProduct values)
        othersProduct i = unSingleton $ elementsProduct $
            V.generate len (\j -> if j == i then 1 else unsafeIndex values j)
        partialAt i
            | xi == 0 = othersProduct i
            | otherwise = total / xi
          where
            xi = unsafeIndex values i
        backward path = extendSingleton path x * constSymbol (V.generate len partialAt)

    -- d(mean)/dx_i = 1 / size.
    mean x = symbolicUnaryOp mean x
        [(x, \path -> extendSingleton path x /. fromIntegral (V.size $ unSymbol x))]

    -- The gradient is x / norm x, scaled by the broadcast upstream gradient.
    norm x = symbolicUnaryOp norm x
        [(x, \path -> extendSingleton path x * (x /. unSingleton (norm $ unSymbol x)))]
-- | Singleton construction and reductions for symbolic matrices.
--
-- Each reduction is recorded with 'symbolicUnaryOp'. Its gradient function
-- receives the upstream gradient (a singleton) and broadcasts it back to the
-- shape of the input with 'extendSingleton'.
instance Symbolic a => SingletonOps (SymbolMat a) where
    singleton = constSymbol . singleton

    isSingleton = isSingleton . unSymbol

    unSingleton = unSingleton . unSymbol

    -- The result is wrapped with 'constSymbol', so the graph links of @mat@
    -- are not carried over.
    extendSingleton mat reference =
        constSymbol $ extendSingleton (unSymbol mat) (unSymbol reference)

    -- d(sum)/dx_ij = 1, so the upstream gradient is simply broadcast.
    elementsSum x = symbolicUnaryOp elementsSum x [(x, (`extendSingleton` x))]

    -- d(product)/dx_ij is the product of every element except x_ij.
    -- Dividing the total by x_ij is cheap but produces 0/0 (NaN) when x_ij is
    -- zero. In that case the product of the remaining elements is computed
    -- explicitly by replacing x_ij with 1.
    elementsProduct x = symbolicUnaryOp elementsProduct x [(x, backward)]
      where
        values = unSymbol x
        shape = M.size values
        total = unSingleton (elementsProduct values)
        othersProduct ix = unSingleton $ elementsProduct $
            M.generate shape (\jx -> if jx == ix then 1 else unsafeIndex values jx)
        partialAt ix
            | xi == 0 = othersProduct ix
            | otherwise = total / xi
          where
            xi = unsafeIndex values ix
        backward path = extendSingleton path x * constSymbol (M.generate shape partialAt)

    -- d(mean)/dx_ij = 1 / number of elements.
    mean x = symbolicUnaryOp mean x
        [(x, \path -> extendSingleton path x /. fromIntegral (M.nElements $ unSymbol x))]

    -- The gradient is x / norm x, scaled by the broadcast upstream gradient.
    norm x = symbolicUnaryOp norm x
        [(x, \path -> extendSingleton path x * (x /. unSingleton (norm $ unSymbol x)))]
-- | Vector operations on symbolic vectors.
instance Symbolic a => VecOps (SymbolVec a) where
    -- The dot product is the sum of the elementwise product. It is built
    -- from symbolic operations, so its gradient comes from '*' and
    -- 'elementsSum'.
    dot u v = elementsSum (u * v)
-- | Matrix operations on symbolic matrices, with their backward passes.
--
-- Each gradient function maps the upstream gradient (shaped like the
-- operation's result) to a gradient shaped like the corresponding operand.
instance Symbolic a => MatOps (SymbolMat a) where
    -- If y = transpose x then dL/dx = transpose (dL/dy). The upstream gradient
    -- is transposed symbolically, so higher-order gradients still see it.
    transpose x = symbolicUnaryOp M.transpose x [(x, transpose)]

    -- @row@ is added to every row of @mat@.
    -- * w.r.t. @mat@: the upstream gradient passes through unchanged.
    -- * w.r.t. @row@: every output row received @row@, so the upstream
    --   gradient is summed over all rows into a single 1 x cols matrix.
    --   Using only row 0 would ignore every other row's contribution.
    -- NOTE(review): assumes 'M.size' returns (rows, cols) and indices are
    -- (row, col), consistent with 'M.indexRow' / 'M.rowVec' usage elsewhere.
    addMatRow mat row = symbolicBinaryOp addMatRow mat row [(mat, id), (row, rowGradient)]
      where
        rowGradient path =
            let upstream = unSymbol path
                (rowCount, colCount) = M.size upstream
            in constSymbol $ M.generate (1, colCount) $ \(_, c) ->
                   sum [unsafeIndex upstream (r, c) | r <- [0 .. rowCount - 1]]

    -- C = A B:  dL/dA = dL/dC B^T  and  dL/dB = A^T dL/dC.
    matMul a b = symbolicBinaryOp M.matMul a b
        [ (a, (`matMul` transpose b))
        , (b, (transpose a `matMul`))
        ]
-- | Gradients produced by 'getGradientsOf'. Each entry maps a symbol to the
-- gradient of the differentiated symbol with respect to it.
newtype Gradients a = Gradients { unGradients :: HM.HashMap (Symbol a) (Symbol a) }
-- | All @(symbol, gradient)@ pairs held in a 'Gradients' table.
-- Order follows the underlying hash map and is unspecified.
allGradients :: Gradients a -> [(Symbol a, Symbol a)]
allGradients (Gradients table) = HM.toList table
-- | Renders a 'Gradients' table as the list of its @(symbol, gradient)@ pairs.
instance Show a => Show (Gradients a) where
    show = show . allGradients
-- | Compute the gradients of a symbol with respect to every symbol it
-- depends on.
--
-- Starting from the seed @symbolicOne differentiatedSymbol@, the graph is
-- walked through 'symbolGradients'. For each @(child, mulPath)@ edge,
-- @mulPath@ maps the gradient arriving at the parent to the contribution for
-- @child@. Contributions reaching the same child are summed. Every path is
-- walked separately; shared sub-graphs are revisited once per path.
--
-- The result also contains the differentiated symbol itself, mapped to the
-- seed.
getGradientsOf :: Symbolic a => Symbol a -> Gradients a
getGradientsOf differentiatedSymbol =
    let seed = symbolicOne differentiatedSymbol
        collected = go HM.empty differentiatedSymbol seed
        -- NOTE(review): removes the entry keyed by a symbol with the empty
        -- identifier (presumably where constant symbols collect). This relies
        -- on Symbol's Eq/Hashable not inspecting the 'undefined' value.
        withoutPlaceholder = HM.delete (Symbol mempty undefined []) collected
    in Gradients (HM.insert differentiatedSymbol seed withoutPlaceholder)
  where
    go grads s pathValue = foldr visit grads (symbolGradients s)
      where
        visit (child, mulPath) acc =
            let childPath = mulPath pathValue
                acc' = HM.alter (Just . maybe childPath (+ childPath)) child acc
            in go acc' child childPath
-- | Look up the gradient with respect to a symbol.
-- Symbols absent from the table get @symbolicZero x@.
wrt :: Symbolic a => Gradients a -> Symbol a -> Symbol a
wrt (Gradients table) x = HM.findWithDefault (symbolicZero x) x table
-- | Repeated partial differentiation.
-- Differentiate the starting symbol with respect to each variable in turn,
-- left to right. Each intermediate result is forced strictly by 'foldl''.
nthPartialGradient :: Symbolic a => Symbol a -> [Symbol a] -> Symbol a
nthPartialGradient start variables = foldl' differentiate start variables
  where
    differentiate current variable = wrt (getGradientsOf current) variable
-- | The @n@-th order gradient of @y@ with respect to @x@.
-- Order 0 returns @y@ unchanged. A negative order raises an 'error'.
nthGradient :: Symbolic a => Int -> Symbol a -> Symbol a -> Symbol a
nthGradient n y x = case compare n 0 of
    LT -> error "Cannot take negative order gradient"
    _  -> nthPartialGradient y (replicate n x)