{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TypeApplications #-}
module Algorithm.SRTree.Likelihoods
( Distribution (..)
, PVector
, SRMatrix
, sse
, mse
, rmse
, r2
, nll
, predict
, buildNLL
, buildNLLEGraph
, gradNLL
, gradNLLArr
, gradNLLGraph
, gradNLLEGraph
, fisherNLL
, getSErr
, hessianNLL
, tree2arr
)
where
import Algorithm.SRTree.AD ( reverseModeArr, reverseModeGraph, reverseModeEGraph )
import Data.Massiv.Array hiding (all, map, read, replicate, tail, take, zip)
import qualified Data.Massiv.Array as M
import qualified Data.Massiv.Array.Mutable as Mut
import Data.Maybe (fromMaybe)
import Data.SRTree
import Data.SRTree.Recursion ( cata, accu )
import Data.SRTree.Derivative (deriveByParam, deriveByVar, derivative)
import Data.SRTree.Eval
import qualified Data.IntMap.Strict as IntMap
import qualified Data.Vector.Storable as VS
import GHC.IO (unsafePerformIO)
import Data.Maybe
import Debug.Trace
import Data.SRTree.Print
import Algorithm.EqSat.Egraph
import Algorithm.EqSat.Simplify
import Algorithm.EqSat.Build
import Control.Monad.State.Strict
import Control.Monad.Identity
import Data.SRTree.Print
-- | Likelihood models understood by 'nll'.
data Distribution
  = MSE        -- ^ plain mean squared error, no explicit noise model
  | Gaussian   -- ^ Gaussian noise with a single, shared variance
  | HGaussian  -- ^ Gaussian noise with per-observation variances given as measured errors
  | Bernoulli  -- ^ binary targets (0 or 1), model output interpreted as log-odds
  | Poisson    -- ^ non-negative count targets, model output interpreted as a log rate
  | ROXY       -- ^ ROXY likelihood: needs x/y errors and three extra parameters after the tree's own
  deriving (Eq, Show, Read, Enum, Bounded)
-- | Sum of squared errors between the targets @ys@ and the predictions of
-- @tree@ evaluated on @xss@ with parameters @theta@.
sse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
sse xss ys tree theta = M.sum $ (delay ys - yhat) ^ (2 :: Int)
  where
    yhat = evalTree xss theta tree
-- | Weighted sum of squared errors: each squared residual
-- @(y - yhat)^2@ is divided by the matching entry of @yErr@ before summing.
-- The predictions come from @tree@ evaluated on @xss@ with @theta@.
sseError :: SRMatrix -> PVector -> PVector -> Fix SRTree -> PVector -> Double
sseError xss ys yErr tree theta =
  M.sum $ ((delay ys - yhat) ^ (2 :: Int)) / delay yErr
  where
    yhat = evalTree xss theta tree
-- | Total sum of squares of the targets around their mean.
--
-- Only @ys@ is used; the data matrix, tree and parameters are accepted so
-- the signature matches 'sse' (as used by 'r2').
sseTot :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
sseTot _xss ys _tree _theta = M.sum $ (M.map (subtract ym) ys) ^ (2 :: Int)
  where
    (Sz m) = M.size ys
    ym     = M.sum ys / fromIntegral m
-- | Mean squared error: 'sse' divided by the number of targets in @ys@.
mse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
mse xss ys tree theta = total / fromIntegral nObs
  where
    (Sz nObs) = M.size ys
    total     = sse xss ys tree theta
-- | Root mean squared error: the square root of 'mse'.
rmse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
rmse xss ys tree theta = sqrt (mse xss ys tree theta)
-- | Coefficient of determination: one minus the ratio of the residual sum
-- of squares ('sse') to the total sum of squares ('sseTot').
r2 :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
r2 xss ys tree theta = 1 - residual / total
  where
    residual = sse xss ys tree theta
    total    = sseTot xss ys tree theta
-- | Standard logistic (sigmoid) function, @1 / (1 + exp (-x))@.
logistic :: Floating a => a -> a
logistic x = recip (1 + exp (negate x))
{-# INLINE logistic #-}
-- | Standard error to use for a distribution.
--
-- For 'Gaussian', a supplied error (@Just e@) takes precedence over the
-- estimate @est@; every other distribution uses a fixed value of 1.
getSErr :: Num a => Distribution -> a -> Maybe a -> a
getSErr dist est mErr =
  case dist of
    Gaussian -> fromMaybe est mErr
    _        -> 1
{-# INLINE getSErr #-}
-- | Negated sum of all entries of a parameter vector.
negSum :: PVector -> Double
negSum v = negate (M.sum v)
{-# INLINE negSum #-}
-- | Negative log-likelihood of the targets @ys@ under the given
-- 'Distribution', for the expression @tree@ evaluated on @xss@ with
-- parameters @theta@.
--
-- The @Maybe PVector@ argument carries measured errors. It is required by
-- 'HGaussian' and 'ROXY' and ignored by the other distributions.
--
-- Calls 'error' when the inputs do not fit the chosen distribution (see the
-- guards of each equation).
nll :: Distribution
    -> Maybe PVector
    -> SRMatrix
    -> PVector
    -> Fix SRTree
    -> PVector
    -> Double

-- MSE: no likelihood model, just the mean squared error.
nll MSE _ xss ys t theta = mse xss ys t theta

-- Gaussian with a single variance, profiled out as its maximum-likelihood
-- estimate sigma^2 = MSE:
--   0.5 * (SSE / sigma^2 + m * log (2 * pi * sigma^2))
nll Gaussian _ xss ys t theta
  -- NOTE(review): this guard fires when theta has exactly one more entry than
  -- the tree has parameters, which looks inverted relative to its message.
  -- It is kept unchanged because callers may rely on it; confirm the intent.
  | nParams == p' - 1 = error "For Gaussian distribution theta must contain the variance as its last value."
  | otherwise         = 0.5 * (sse xss ys t theta / s2 + m * log (2 * pi * s2))
  where
    -- Variance, not standard deviation. Using sqrt here made the formula
    -- inconsistent (SSE / sigma instead of SSE / sigma^2).
    s2      = mse xss ys t theta
    (Sz m') = M.size ys
    (Sz p') = M.size theta
    nParams = countParamsUniq t
    m       = fromIntegral m'

-- HGaussian: Gaussian with per-observation variances e taken from the
-- measured errors:
--   0.5 * (sum ((y - yhat)^2 / e) + sum (log (2 * pi * e)))
nll HGaussian mYerr xss ys t theta =
  case mYerr of
    Nothing   -> error "For HGaussian, you must provide the measured error for the target variable."
    Just yErr -> 0.5 * (sseError xss ys yErr t theta + M.sum (M.map (log . (2*) . (pi*)) yErr))

-- Bernoulli with the model output interpreted as log-odds:
--   sum ((1 - y) * yhat + log (1 + exp (-yhat)))
nll Bernoulli _ xss ys tree theta
  | notValid ys = error "For Bernoulli distribution the output must be either 0 or 1."
  | otherwise   = M.sum $ (M.map (1-) (delay ys)) * yhat + log (M.map (1+) $ exp (M.map negate yhat))
  where
    yhat     = evalTree xss theta tree
    notValid = M.any (\x -> x /= 0 && x /= 1)

-- Poisson with the model output interpreted as a log rate:
--   -sum (y * yhat - y * log y - exp yhat)
nll Poisson _ xss ys tree theta
  | notValid ys = error "For Poisson distribution the output must be non-negative."
  | otherwise   = negate . M.sum $ ys' * yhat - M.map xlogx ys' - exp yhat
  where
    ys'      = delay ys
    yhat     = evalTree xss theta tree
    notValid = M.any (< 0)

    -- y * log y, taking its limit 0 at y = 0. The plain @ys' * log ys'@
    -- evaluated to 0 * (-Infinity) = NaN for every zero count, which is a
    -- valid Poisson target.
    xlogx :: Double -> Double
    xlogx y
      | y == 0    = 0
      | otherwise = y * log y

-- ROXY: the expression is evaluated together with its forward-mode derivative
-- with respect to the single input variable. The three parameters after the
-- tree's own (sig, mu_gauss, w_gauss) enter the per-point likelihood.
-- A NaN result is reported as +Infinity.
nll ROXY mYerr xss _ys tree theta
  | isNothing mYerr     = error "Can't calculate ROXY nll without x,y-errors."
  | p' < num_params + 3 = error "We need 3 additional parameters for ROXY."
  | n /= 1 && n /= 5    = error "For ROXY dataset must contain a single variable, or 1 variable + 4 cached data."
  | otherwise           = if isNaN negLL then 1.0 / 0.0 else negLL
  where
    (Sz p')    = M.size theta
    (Sz2 m n)  = M.size xss
    num_params = countParamsUniq tree

    -- Column 0 is the input variable; columns 1-4 are read as the cached
    -- log-space data. mYerr is only checked for presence above.
    -- NOTE(review): when n == 1 (allowed by the guard) these slices of
    -- columns 1-4 are out of bounds; confirm how that case should work.
    x0      = xss <! 0
    logX    = xss <! 1
    logY    = xss <! 2
    logXErr = xss <! 3
    logYErr = xss <! 4

    one  = M.replicate compMode (Sz m) 1
    zero = M.replicate compMode (Sz m) 0

    -- Extra parameters stored right after the tree's own parameters.
    (sig, mu_gauss, w_gauss) = (theta ! num_params, theta ! (num_params + 1), theta ! (num_params + 2))

    -- Derivative of a binary operator given both operands and their derivatives.
    applyDer :: Op -> Array D Ix1 Double -> Array D Ix1 Double -> Array D Ix1 Double -> Array D Ix1 Double -> Array D Ix1 Double
    applyDer Add      _ dl _ dr = dl + dr
    applyDer Sub      _ dl _ dr = dl - dr
    applyDer Mul      l dl r dr = l * dr + r * dl
    applyDer Div      l dl r dr = (dl * r - dr * l) / (r ^ 2)
    applyDer Power    l dl r dr = l ** (r .- 1) * (r * dl + l * log l * dr)
    applyDer PowerAbs l dl r dr = (abs l ** r) * (dr * log (abs l) + r * dl / l)
    applyDer AQ       l dl r dr = ((1 +. r * r) * dl - l * r * dr) / M.map (** 1.5) (1 +. r * r)

    -- Value and derivative w.r.t. the input variable, computed bottom-up.
    (yhat, grad) = cata alg tree
      where
        alg (Var _)              = (x0, one)
        alg (Param ix)           = (M.replicate compMode (Sz m) (theta M.! ix), zero)
        alg (Const x)            = (M.replicate compMode (Sz m) x, zero)
        alg (Uni fun (val, der)) = (M.map (evalFun fun) val, M.map (derivative fun) val * der)
        alg (Bin op (valL, derL) (valR, derR)) =
          (M.zipWith (evalOp op) valL valR, applyDer op valL derL valR derR)

    f        = M.map (logBase 10) (abs yhat)
    -- Slope of log10 |yhat| with respect to log10 x0 (algebraically grad * x0 / yhat).
    fprime   = grad / (log 10 *. yhat) * x0 .* log 10
    w_gauss2 = w_gauss ^ 2
    s2       = delay $ logYErr .+ sig ^ 2
    den      = fprime ^ 2 .* w_gauss2 * logXErr + s2 * (w_gauss2 +. logXErr)
    neglogP  = log (2 * pi)
               +. log den
               + (w_gauss2 *. (f - logY) * (f - logY)
                  + logXErr * (fprime * (mu_gauss -. logX) + f - logY) ^ 2
                  + s2 * (logX .- mu_gauss) ^ 2) / den
    negLL    = 0.5 * M.sum neglogP
-- | Build the per-sample negative log-likelihood of @tree@ as an expression
-- tree for the chosen distribution.
--
-- @var (-1)@ plays the role of the observed target and, for 'HGaussian',
-- @var (-2)@ the per-sample variance (assumed to be bound by the evaluator —
-- confirm against the caller). @m@ is only used by 'MSE' (as the divisor) and
-- 'HGaussian' (as the multiplier of the log-normaliser).
--
-- * 'MSE':       @(tree - y) ** 2 / m@
-- * 'Gaussian':  @square (tree - y) / square s + log (square s)@, where @s@ is
--   a fresh parameter indexed right after the parameters of @tree@
-- * 'HGaussian': @(tree - y) ** 2 / v + m * log (2 * pi * v)@
-- * 'Poisson':   @y * log y + exp tree - y * tree@
-- * 'Bernoulli': @log (1 + exp (negate tree)) + (1 - y) * tree@
-- * 'ROXY':      the ROXY likelihood built from variables 0..4 and three extra
--   parameters (scatter, Gaussian mean, Gaussian width).
--
-- NOTE(review): for 'Poisson', a target of 0 makes the @y * log y@ term
-- @0 * log 0@, which is NaN under IEEE double arithmetic if evaluated directly.
buildNLL :: Distribution -> Double -> Fix SRTree -> Fix SRTree
buildNLL dist m tree = case dist of
    MSE       -> (residual ** 2) / constv m
    Gaussian  -> square residual / square sigma + log (square sigma)
    HGaussian -> residual ** 2 / var (-2) + constv m * log (2 * pi * var (-2))
    Poisson   -> target * log target + exp tree - target * tree
    Bernoulli -> log (1 + exp (negate tree)) + (1 - target) * tree
    ROXY      -> roxy
  where
    target :: Fix SRTree
    target = var (-1)

    residual :: Fix SRTree
    residual = tree - target

    -- index of the first parameter not already used by the model
    nParams :: Int
    nParams = countParamsUniq tree

    -- extra noise/scatter parameter shared by Gaussian and ROXY
    sigma :: Fix SRTree
    sigma = param nParams

    square :: Fix SRTree -> Fix SRTree
    square = Fix . Uni Square

    -- ROXY ingredients ----------------------------------------------------
    ln10 :: Fix SRTree
    ln10 = log 10

    -- base-10 log of the model output
    logF :: Fix SRTree
    logF = log (abs tree) / ln10

    -- d(log10 f) / d(log10 x0), expressed through the derivative w.r.t. var 0
    slope :: Fix SRTree
    slope = deriveByVar 0 tree / (ln10 * tree) * var 0 * ln10

    logX, logY, logXErr, logYErr :: Fix SRTree
    logX    = var 1
    logY    = var 2
    logXErr = var 3
    logYErr = var 4

    muG :: Fix SRTree
    muG = param (nParams + 1)

    wG2 :: Fix SRTree
    wG2 = param (nParams + 2) ** 2

    s2 :: Fix SRTree
    s2 = logYErr + sigma ** 2

    den :: Fix SRTree
    den = slope ** 2 * wG2 * logXErr + s2 * (wG2 + logXErr)

    roxy :: Fix SRTree
    roxy = log (2 * pi) + log den
         + ( wG2 * (logF - logY) * (logF - logY)
           + logXErr * (slope * (muG - logX) + logF - logY) ** 2
           + s2 * (logX - muG) ** 2
           ) / den
-- | Insert the per-sample negative log-likelihood of the e-class @root@ into
-- @egraph@, returning the e-class id of the new expression and the updated
-- e-graph.
--
-- The nodes mirror 'buildNLL', with @y = Var (-1)@ and @v = Var (-2)@:
--
-- * 'MSE':       @(root - y) ^ 2 / m@ (as a @Power@ node with constant 2)
-- * 'Gaussian':  @square (root - y) / square s + log (square s)@, where @s@ is
--   the parameter whose index is 'countParamsUniqEg' of the input e-graph
-- * 'HGaussian': @square (root - y) / v + m * log (2*pi * v)@
-- * 'Poisson':   @y * log y + exp root - y * root@
-- * 'Bernoulli': @log (1 + exp (-1 * root)) + (1 - y) * root@
--
-- 'ROXY' is not supported and calls 'error'.
--
-- Nodes are added in a fixed order; changing that order would change the
-- e-class ids that get assigned.
buildNLLEGraph :: Distribution -> Double -> EGraph -> Int -> (Int, EGraph)
buildNLLEGraph dist m egraph root = runIdentity (runStateT (nllNodes dist) egraph)
  where
    -- add one e-node using the module's cost function, returning its e-class
    addENode n = add myCost n

    nllNodes :: Distribution -> EGraphST Identity EClassId
    nllNodes MSE = do
      y     <- addENode (Var (-1))
      two   <- addENode (Const 2)
      nObs  <- addENode (Const m)
      resid <- addENode (Bin Sub root y)
      sq    <- addENode (Bin Power resid two)
      addENode (Bin Div sq nObs)
    nllNodes Gaussian = do
      y      <- addENode (Var (-1))
      -- fresh noise parameter; counted on the *input* e-graph so it comes
      -- right after the parameters already used by @root@
      sigma  <- addENode (Param (countParamsUniqEg egraph root))
      sigma2 <- addENode (Uni Square sigma)
      logS2  <- addENode (Uni Log sigma2)
      resid  <- addENode (Bin Sub root y)
      sq     <- addENode (Uni Square resid)
      scaled <- addENode (Bin Div sq sigma2)
      addENode (Bin Add scaled logS2)
    nllNodes HGaussian = do
      y       <- addENode (Var (-1))
      v       <- addENode (Var (-2))
      twoPi   <- addENode (Const (2 * pi))
      mConst  <- addENode (Const m)
      resid   <- addENode (Bin Sub root y)
      sq      <- addENode (Uni Square resid)
      scaled  <- addENode (Bin Div sq v)
      w       <- addENode (Bin Mul twoPi v)
      logW    <- addENode (Uni Log w)
      penalty <- addENode (Bin Mul mConst logW)
      addENode (Bin Add scaled penalty)
    nllNodes Poisson = do
      y     <- addENode (Var (-1))
      logY  <- addENode (Uni Log y)
      yLogY <- addENode (Bin Mul y logY)
      rate  <- addENode (Uni Exp root)
      acc   <- addENode (Bin Add yLogY rate)
      yRoot <- addENode (Bin Mul y root)
      addENode (Bin Sub acc yRoot)
    nllNodes Bernoulli = do
      y         <- addENode (Var (-1))
      one       <- addENode (Const 1)
      minusOne  <- addENode (Const (-1))
      negRoot   <- addENode (Bin Mul minusOne root)
      expNeg    <- addENode (Uni Exp negRoot)
      onePlus   <- addENode (Bin Add one expNeg)
      softplus  <- addENode (Uni Log onePlus)
      oneMinusY <- addENode (Bin Sub one y)
      weighted  <- addENode (Bin Mul oneMinusY root)
      addENode (Bin Add softplus weighted)
    nllNodes ROXY = error "ROXY not supported with cache"
-- | Evaluate @tree@ on the rows of @xss@ with parameters @theta@, then map the
-- raw output through the inverse link of the distribution:
--
-- * 'Bernoulli' applies 'logistic',
-- * 'Poisson' applies 'exp',
-- * 'MSE', 'Gaussian', 'HGaussian' and 'ROXY' return the raw output.
--
-- 'HGaussian' uses the identity link because its likelihood in 'buildNLL'
-- compares the raw @tree@ output directly with the target @var (-1)@.
predict :: Distribution -> Fix SRTree -> PVector -> SRMatrix -> SRVector
predict dist tree theta xss = inverseLink dist (evalTree xss theta tree)
  where
    -- Listed exhaustively so that adding a new 'Distribution' constructor
    -- triggers an incomplete-pattern warning here.
    inverseLink :: Distribution -> SRVector -> SRVector
    inverseLink MSE       = id
    inverseLink Gaussian  = id
    inverseLink HGaussian = id
    inverseLink Bernoulli = logistic
    inverseLink Poisson   = exp
    inverseLink ROXY      = id
-- | Forward finite-difference approximation of the loss gradient of @tree@
-- with respect to @theta@.
--
-- Returns @(loss, gradient)@.
--
-- * @loss@ is the mean squared error of @predict MSE tree theta xss@ against
--   @ys@, averaged over the rows of @xss@.
-- * Entry @i@ of @gradient@ is @(loss (theta + eps * e_i) - loss theta) / eps@,
--   with @eps = 1e-8@.
--
-- NOTE(review): the distribution and @mYerr@ arguments are currently ignored;
-- the loss is always the MSE whatever distribution is requested. They are kept
-- so the signature stays compatible with existing callers.
gradNLL :: Distribution
        -> Maybe PVector
        -> SRMatrix
        -> PVector
        -> Fix SRTree
        -> PVector
        -> (Double, SRVector)
gradNLL _dist _mYerr xss ys tree theta = (f, delay grad)
  where
    -- number of samples (rows of the data matrix)
    (Sz2 m _) = M.size xss
    -- number of parameters
    (Sz p) = M.size theta

    -- step size of the forward difference
    eps :: Double
    eps = 1e-8

    -- mean squared error of the model evaluated at parameters @t@
    mseAt :: PVector -> Double
    mseAt t = (/ fromIntegral m) . M.sum . M.map (^ (2 :: Int))
            $ predict MSE tree t xss - delay ys

    -- loss at the unperturbed parameters, shared by every partial derivative
    f :: Double
    f = mseAt theta

    -- copy of @theta@ with @eps@ added to entry @ix@
    disturb :: Int -> PVector
    disturb ix = M.fromList M.Seq
               $ Prelude.zipWith (\iy v -> if iy == ix then v + eps else v) [0 ..] (M.toList theta)

    finitediff :: Int -> Double
    finitediff ix = (mseAt (disturb ix) - f) / eps

    grad :: PVector
    grad = M.fromList M.Seq [finitediff ix | ix <- [0 .. p - 1]]
-- | Post-processing hook applied to every gradient entry before it is
-- returned. It is currently the identity function.
--
-- NOTE(review): despite its name, this does not replace NaN with 0. Every
-- value, NaN included, is returned unchanged, so NaNs still reach the caller.
nanTo0 :: a -> a
nanTo0 = id
{-# INLINE nanTo0 #-}
-- | Value and gradient of the negative log-likelihood, computed with
-- 'reverseModeArr'. The array-encoded expression @tree@ and the index map
-- @j2ix@ are passed to it unchanged.
--
-- Returns @(sum of the per-sample values, gradient)@, with 'nanTo0' applied
-- to every gradient entry. For 'ROXY', both the value and the gradient are
-- multiplied by 0.5.
--
-- The target vector @ys@ is validated first:
--
-- * 'Bernoulli' calls 'error' unless every target is exactly 0 or 1.
-- * 'Poisson' calls 'error' if any target is negative.
--
-- 'MSE', 'Gaussian', 'HGaussian' and valid 'Bernoulli'/'Poisson' inputs all
-- share the same unscaled path. 'HGaussian' previously had no equation and
-- failed with a pattern-match error.
gradNLLArr :: Distribution
           -> SRMatrix
           -> PVector
           -> Maybe PVector
           -> [(Int, (Int, Int, Int, Double))]
           -> IntMap Int
           -> Vector Double
           -> (Double, SRVector)
gradNLLArr dist xss ys mYerr tree j2ix theta =
  case dist of
    -- a guard that fails falls through to the later alternatives
    Bernoulli
      | M.any (\y -> y /= 0 && y /= 1) ys ->
          error "For Bernoulli distribution the output must be either 0 or 1."
    Poisson
      | M.any (< 0) ys ->
          error "For Poisson distribution the output must be non-negative."
    ROXY -> (M.sum yhat * 0.5, M.map (* 0.5) grad')
    _    -> (M.sum yhat, grad')
  where
    (yhat, grad) = reverseModeArr xss ys mYerr theta tree j2ix

    grad' :: SRVector
    grad' = M.map nanTo0 grad
gradNLLGraph :: Distribution
-> SRMatrix
-> PVector
-> Maybe PVector
-> Fix SRTree
-> Vector Double
-> (Double, Vector Double)
gradNLLGraph Distribution
MSE SRMatrix
xss PVector
ys Maybe PVector
mYerr Fix SRTree
tree Vector Double
theta =
(SRVector -> Double
forall ix r e. (Index ix, Source r e, Num e) => Array r ix e -> e
M.sum SRVector
yhat, Vector Double
grad')
where
(SRVector
yhat, Vector Double
grad) = SRMatrix
-> PVector
-> Maybe PVector
-> Vector Double
-> Fix SRTree
-> (SRVector, Vector Double)
reverseModeGraph SRMatrix
xss PVector
ys Maybe PVector
mYerr Vector Double
theta Fix SRTree
tree
grad' :: Vector Double
grad' = (Double -> Double) -> Vector Double -> Vector Double
forall a b.
(Storable a, Storable b) =>
(a -> b) -> Vector a -> Vector b
VS.map Double -> Double
forall {p}. p -> p
nanTo0 Vector Double
grad
gradNLLGraph Distribution
Gaussian SRMatrix
xss PVector
ys Maybe PVector
mYerr Fix SRTree
tree Vector Double
theta =
(SRVector -> Double
forall ix r e. (Index ix, Source r e, Num e) => Array r ix e -> e
M.sum SRVector
yhat, Vector Double
grad')
where
(SRVector
yhat, Vector Double
grad) = SRMatrix
-> PVector
-> Maybe PVector
-> Vector Double
-> Fix SRTree
-> (SRVector, Vector Double)
reverseModeGraph SRMatrix
xss PVector
ys Maybe PVector
mYerr Vector Double
theta Fix SRTree
tree
grad' :: Vector Double
grad' = (Double -> Double) -> Vector Double -> Vector Double
forall a b.
(Storable a, Storable b) =>
(a -> b) -> Vector a -> Vector b
VS.map Double -> Double
forall {p}. p -> p
nanTo0 Vector Double
grad
gradNLLGraph Distribution
Bernoulli SRMatrix
xss PVector
ys Maybe PVector
mYerr Fix SRTree
tree Vector Double
theta
| (Double -> Bool) -> PVector -> Bool
forall ix r e.
(Index ix, Source r e) =>
(e -> Bool) -> Array r ix e -> Bool
M.any (\Double
x -> Double
x Double -> Double -> Bool
forall a. Eq a => a -> a -> Bool
/= Double
0 Bool -> Bool -> Bool
&& Double
x Double -> Double -> Bool
forall a. Eq a => a -> a -> Bool
/= Double
1) PVector
ys = String -> (Double, Vector Double)
forall a. HasCallStack => String -> a
error String
"For Bernoulli distribution the output must be either 0 or 1."
| Bool
otherwise = (SRVector -> Double
forall ix r e. (Index ix, Source r e, Num e) => Array r ix e -> e
M.sum SRVector
yhat, Vector Double
grad')
where
(SRVector
yhat, Vector Double
grad) = SRMatrix
-> PVector
-> Maybe PVector
-> Vector Double
-> Fix SRTree
-> (SRVector, Vector Double)
reverseModeGraph SRMatrix
xss PVector
ys Maybe PVector
mYerr Vector Double
theta Fix SRTree
tree
grad' :: Vector Double
grad' = (Double -> Double) -> Vector Double -> Vector Double
forall a b.
(Storable a, Storable b) =>
(a -> b) -> Vector a -> Vector b
VS.map Double -> Double
forall {p}. p -> p
nanTo0 Vector Double
grad
gradNLLGraph Distribution
Poisson SRMatrix
xss PVector
ys Maybe PVector
mYerr Fix SRTree
tree Vector Double
theta
| (Double -> Bool) -> PVector -> Bool
forall ix r e.
(Index ix, Source r e) =>
(e -> Bool) -> Array r ix e -> Bool
M.any (Double -> Double -> Bool
forall a. Ord a => a -> a -> Bool
<Double
0) PVector
ys = String -> (Double, Vector Double)
forall a. HasCallStack => String -> a
error String
"For Poisson distribution the output must be non-negative."
| Bool
otherwise = (SRVector -> Double
forall ix r e. (Index ix, Source r e, Num e) => Array r ix e -> e
M.sum SRVector
yhat, Vector Double
grad')
where
(SRVector
yhat, Vector Double
grad) = SRMatrix
-> PVector
-> Maybe PVector
-> Vector Double
-> Fix SRTree
-> (SRVector, Vector Double)
reverseModeGraph SRMatrix
xss PVector
ys Maybe PVector
mYerr Vector Double
theta Fix SRTree
tree
grad' :: Vector Double
grad' = (Double -> Double) -> Vector Double -> Vector Double
forall a b.
(Storable a, Storable b) =>
(a -> b) -> Vector a -> Vector b
VS.map Double -> Double
forall {p}. p -> p
nanTo0 Vector Double
grad
gradNLLGraph Distribution
ROXY SRMatrix
xss PVector
ys Maybe PVector
mYerr Fix SRTree
tree Vector Double
theta =
((Double -> Double -> Double
forall a. Num a => a -> a -> a
*Double
0.5) (Double -> Double) -> Double -> Double
forall a b. (a -> b) -> a -> b
$ SRVector -> Double
forall ix r e. (Index ix, Source r e, Num e) => Array r ix e -> e
M.sum SRVector
yhat, (Double -> Double) -> Vector Double -> Vector Double
forall a b.
(Storable a, Storable b) =>
(a -> b) -> Vector a -> Vector b
VS.map (Double -> Double -> Double
forall a. Num a => a -> a -> a
*(Double
0.5)) (Vector Double -> Vector Double) -> Vector Double -> Vector Double
forall a b. (a -> b) -> a -> b
$ Vector Double
grad')
where
(SRVector
yhat, Vector Double
grad) = SRMatrix
-> PVector
-> Maybe PVector
-> Vector Double
-> Fix SRTree
-> (SRVector, Vector Double)
reverseModeGraph SRMatrix
xss PVector
ys Maybe PVector
mYerr Vector Double
theta Fix SRTree
tree
grad' :: Vector Double
grad' = (Double -> Double) -> Vector Double -> Vector Double
forall a b.
(Storable a, Storable b) =>
(a -> b) -> Vector a -> Vector b
VS.map Double -> Double
forall {p}. p -> p
nanTo0 Vector Double
grad
gradNLLEGraph :: Distribution
-> SRMatrix
-> PVector
-> Maybe PVector
-> EGraph
-> ECache
-> Int
-> Vector Double
-> (Double, Vector Double)
gradNLLEGraph Distribution
MSE SRMatrix
xss PVector
ys Maybe PVector
mYerr EGraph
egraph ECache
cache Int
root Vector Double
theta =
(SRVector -> Double
forall ix r e. (Index ix, Source r e, Num e) => Array r ix e -> e
M.sum SRVector
yhat, Vector Double
grad')
where
(SRVector
yhat, Vector Double
grad) = SRMatrix
-> PVector
-> Maybe PVector
-> EGraph
-> ECache
-> Int
-> Vector Double
-> (SRVector, Vector Double)
reverseModeEGraph SRMatrix
xss PVector
ys Maybe PVector
mYerr EGraph
egraph ECache
cache Int
root Vector Double
theta
grad' :: Vector Double
grad' = (Double -> Double) -> Vector Double -> Vector Double
forall a b.
(Storable a, Storable b) =>
(a -> b) -> Vector a -> Vector b
VS.map Double -> Double
forall {p}. p -> p
nanTo0 Vector Double
grad
gradNLLEGraph Distribution
Gaussian SRMatrix
xss PVector
ys Maybe PVector
mYerr EGraph
egraph ECache
cache Int
root Vector Double
theta =
(SRVector -> Double
forall ix r e. (Index ix, Source r e, Num e) => Array r ix e -> e
M.sum SRVector
yhat, Vector Double
grad')
where
(SRVector
yhat, Vector Double
grad) = SRMatrix
-> PVector
-> Maybe PVector
-> EGraph
-> ECache
-> Int
-> Vector Double
-> (SRVector, Vector Double)
reverseModeEGraph SRMatrix
xss PVector
ys Maybe PVector
mYerr EGraph
egraph ECache
cache Int
root Vector Double
theta
grad' :: Vector Double
grad' = (Double -> Double) -> Vector Double -> Vector Double
forall a b.
(Storable a, Storable b) =>
(a -> b) -> Vector a -> Vector b
VS.map Double -> Double
forall {p}. p -> p
nanTo0 Vector Double
grad
gradNLLEGraph Distribution
Bernoulli SRMatrix
xss PVector
ys Maybe PVector
mYerr EGraph
egraph ECache
cache Int
root Vector Double
theta
| (Double -> Bool) -> PVector -> Bool
forall ix r e.
(Index ix, Source r e) =>
(e -> Bool) -> Array r ix e -> Bool
M.any (\Double
x -> Double
x Double -> Double -> Bool
forall a. Eq a => a -> a -> Bool
/= Double
0 Bool -> Bool -> Bool
&& Double
x Double -> Double -> Bool
forall a. Eq a => a -> a -> Bool
/= Double
1) PVector
ys = String -> (Double, Vector Double)
forall a. HasCallStack => String -> a
error String
"For Bernoulli distribution the output must be either 0 or 1."
| Bool
otherwise = (SRVector -> Double
forall ix r e. (Index ix, Source r e, Num e) => Array r ix e -> e
M.sum SRVector
yhat, Vector Double
grad')
where
(SRVector
yhat, Vector Double
grad) = SRMatrix
-> PVector
-> Maybe PVector
-> EGraph
-> ECache
-> Int
-> Vector Double
-> (SRVector, Vector Double)
reverseModeEGraph SRMatrix
xss PVector
ys Maybe PVector
mYerr EGraph
egraph ECache
cache Int
root Vector Double
theta
grad' :: Vector Double
grad' = (Double -> Double) -> Vector Double -> Vector Double
forall a b.
(Storable a, Storable b) =>
(a -> b) -> Vector a -> Vector b
VS.map Double -> Double
forall {p}. p -> p
nanTo0 Vector Double
grad
gradNLLEGraph Distribution
Poisson SRMatrix
xss PVector
ys Maybe PVector
mYerr EGraph
egraph ECache
cache Int
root Vector Double
theta
| (Double -> Bool) -> PVector -> Bool
forall ix r e.
(Index ix, Source r e) =>
(e -> Bool) -> Array r ix e -> Bool
M.any (Double -> Double -> Bool
forall a. Ord a => a -> a -> Bool
<Double
0) PVector
ys = String -> (Double, Vector Double)
forall a. HasCallStack => String -> a
error String
"For Poisson distribution the output must be non-negative."
| Bool
otherwise = (SRVector -> Double
forall ix r e. (Index ix, Source r e, Num e) => Array r ix e -> e
M.sum SRVector
yhat, Vector Double
grad')
where
(SRVector
yhat, Vector Double
grad) = SRMatrix
-> PVector
-> Maybe PVector
-> EGraph
-> ECache
-> Int
-> Vector Double
-> (SRVector, Vector Double)
reverseModeEGraph SRMatrix
xss PVector
ys Maybe PVector
mYerr EGraph
egraph ECache
cache Int
root Vector Double
theta
grad' :: Vector Double
grad' = (Double -> Double) -> Vector Double -> Vector Double
forall a b.
(Storable a, Storable b) =>
(a -> b) -> Vector a -> Vector b
VS.map Double -> Double
forall {p}. p -> p
nanTo0 Vector Double
grad
gradNLLEGraph Distribution
ROXY SRMatrix
xss PVector
ys Maybe PVector
mYerr EGraph
egraph ECache
cache Int
root Vector Double
theta =
((Double -> Double -> Double
forall a. Num a => a -> a -> a
*Double
0.5) (Double -> Double) -> Double -> Double
forall a b. (a -> b) -> a -> b
$ SRVector -> Double
forall ix r e. (Index ix, Source r e, Num e) => Array r ix e -> e
M.sum SRVector
yhat, (Double -> Double) -> Vector Double -> Vector Double
forall a b.
(Storable a, Storable b) =>
(a -> b) -> Vector a -> Vector b
VS.map (Double -> Double -> Double
forall a. Num a => a -> a -> a
*(Double
0.5)) (Vector Double -> Vector Double) -> Vector Double -> Vector Double
forall a b. (a -> b) -> a -> b
$ Vector Double
grad')
where
(SRVector
yhat, Vector Double
grad) = SRMatrix
-> PVector
-> Maybe PVector
-> EGraph
-> ECache
-> Int
-> Vector Double
-> (SRVector, Vector Double)
reverseModeEGraph SRMatrix
xss PVector
ys Maybe PVector
mYerr EGraph
egraph ECache
cache Int
root Vector Double
theta
grad' :: Vector Double
grad' = (Double -> Double) -> Vector Double -> Vector Double
forall a b.
(Storable a, Storable b) =>
(a -> b) -> Vector a -> Vector b
VS.map Double -> Double
forall {p}. p -> p
nanTo0 Vector Double
grad
fisherNLL :: Distribution -> Maybe PVector -> SRMatrix -> PVector -> Fix SRTree -> PVector -> SRVector
fisherNLL :: Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> SRVector
fisherNLL Distribution
ROXY Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
theta = Comp -> Sz Int -> (Int -> Double) -> SRVector
forall r ix e.
Load r ix e =>
Comp -> Sz ix -> (ix -> e) -> Array r ix e
makeArray Comp
cmp (Int -> Sz Int
forall ix. Index ix => ix -> Sz ix
Sz Int
p) Int -> Double
finiteDiff
where
cmp :: Comp
cmp = SRMatrix -> Comp
forall r ix e. Strategy r => Array r ix e -> Comp
forall ix e. Array S ix e -> Comp
getComp SRMatrix
xss
(Sz Int
m) = PVector -> Sz Int
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size PVector
ys
(Sz Int
p) = PVector -> Sz Int
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size PVector
theta
f :: Double
f = Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> Double
nll Distribution
ROXY Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
theta
eps :: Double
eps = Double
1e-6
finiteDiff :: Int -> Double
finiteDiff Int
ix = IO Double -> Double
forall a. IO a -> a
unsafePerformIO (IO Double -> Double) -> IO Double -> Double
forall a b. (a -> b) -> a -> b
$ do
MArray RealWorld S Int Double
theta' <- PVector -> IO (MArray RealWorld S Int Double)
forall r ix e (m :: * -> *).
(Manifest r e, Index ix, MonadIO m) =>
Array r ix e -> m (MArray RealWorld r ix e)
Mut.thaw PVector
theta
Double
v <- MArray (PrimState IO) S Int Double -> Int -> IO Double
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m, MonadThrow m) =>
MArray (PrimState m) r ix e -> ix -> m e
Mut.readM MArray RealWorld S Int Double
MArray (PrimState IO) S Int Double
theta' Int
ix
MArray (PrimState IO) S Int Double -> Int -> Double -> IO ()
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m, MonadThrow m) =>
MArray (PrimState m) r ix e -> ix -> e -> m ()
Mut.writeM MArray RealWorld S Int Double
MArray (PrimState IO) S Int Double
theta' Int
ix (Double
v Double -> Double -> Double
forall a. Num a => a -> a -> a
+ Double
eps)
PVector
thetaPlus <- MArray (PrimState IO) S Int Double -> IO PVector
forall r ix e (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> m (Array r ix e)
Mut.freezeS MArray RealWorld S Int Double
MArray (PrimState IO) S Int Double
theta'
MArray (PrimState IO) S Int Double -> Int -> Double -> IO ()
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m, MonadThrow m) =>
MArray (PrimState m) r ix e -> ix -> e -> m ()
Mut.writeM MArray RealWorld S Int Double
MArray (PrimState IO) S Int Double
theta' Int
ix (Double
v Double -> Double -> Double
forall a. Num a => a -> a -> a
- Double
eps)
PVector
thetaMinus <- MArray (PrimState IO) S Int Double -> IO PVector
forall r ix e (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> m (Array r ix e)
Mut.freezeS MArray RealWorld S Int Double
MArray (PrimState IO) S Int Double
theta'
let fPlus :: Double
fPlus = Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> Double
nll Distribution
ROXY Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
thetaPlus
fMinus :: Double
fMinus = Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> Double
nll Distribution
ROXY Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
thetaMinus
Double -> IO Double
forall a. a -> IO a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Double -> IO Double) -> Double -> IO Double
forall a b. (a -> b) -> a -> b
$ (Double
fPlus Double -> Double -> Double
forall a. Num a => a -> a -> a
+ Double
fMinus Double -> Double -> Double
forall a. Num a => a -> a -> a
- Double
2Double -> Double -> Double
forall a. Num a => a -> a -> a
*Double
f)Double -> Double -> Double
forall a. Fractional a => a -> a -> a
/(Double
epsDouble -> Double -> Double
forall a. Num a => a -> a -> a
*Double
eps)
fisherNLL Distribution
Gaussian Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
theta = Comp -> Sz Int -> (Int -> Double) -> SRVector
forall r ix e.
Load r ix e =>
Comp -> Sz ix -> (ix -> e) -> Array r ix e
makeArray Comp
cmp (Int -> Sz Int
forall ix. Index ix => ix -> Sz ix
Sz Int
p) Int -> Double
finiteDiff
where
cmp :: Comp
cmp = SRMatrix -> Comp
forall r ix e. Strategy r => Array r ix e -> Comp
forall ix e. Array S ix e -> Comp
getComp SRMatrix
xss
(Sz Int
m) = PVector -> Sz Int
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size PVector
ys
(Sz Int
p) = PVector -> Sz Int
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size PVector
theta
f :: Double
f = Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> Double
nll Distribution
Gaussian Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
theta
eps :: Double
eps = Double
1e-6
finiteDiff :: Int -> Double
finiteDiff Int
ix = IO Double -> Double
forall a. IO a -> a
unsafePerformIO (IO Double -> Double) -> IO Double -> Double
forall a b. (a -> b) -> a -> b
$ do
MArray RealWorld S Int Double
theta' <- PVector -> IO (MArray RealWorld S Int Double)
forall r ix e (m :: * -> *).
(Manifest r e, Index ix, MonadIO m) =>
Array r ix e -> m (MArray RealWorld r ix e)
Mut.thaw PVector
theta
Double
v <- MArray (PrimState IO) S Int Double -> Int -> IO Double
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m, MonadThrow m) =>
MArray (PrimState m) r ix e -> ix -> m e
Mut.readM MArray RealWorld S Int Double
MArray (PrimState IO) S Int Double
theta' Int
ix
MArray (PrimState IO) S Int Double -> Int -> Double -> IO ()
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m, MonadThrow m) =>
MArray (PrimState m) r ix e -> ix -> e -> m ()
Mut.writeM MArray RealWorld S Int Double
MArray (PrimState IO) S Int Double
theta' Int
ix (Double
v Double -> Double -> Double
forall a. Num a => a -> a -> a
+ Double
eps)
PVector
thetaPlus <- MArray (PrimState IO) S Int Double -> IO PVector
forall r ix e (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> m (Array r ix e)
Mut.freezeS MArray RealWorld S Int Double
MArray (PrimState IO) S Int Double
theta'
MArray (PrimState IO) S Int Double -> Int -> Double -> IO ()
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m, MonadThrow m) =>
MArray (PrimState m) r ix e -> ix -> e -> m ()
Mut.writeM MArray RealWorld S Int Double
MArray (PrimState IO) S Int Double
theta' Int
ix (Double
v Double -> Double -> Double
forall a. Num a => a -> a -> a
- Double
eps)
PVector
thetaMinus <- MArray (PrimState IO) S Int Double -> IO PVector
forall r ix e (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> m (Array r ix e)
Mut.freezeS MArray RealWorld S Int Double
MArray (PrimState IO) S Int Double
theta'
let fPlus :: Double
fPlus = Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> Double
nll Distribution
Gaussian Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
thetaPlus
fMinus :: Double
fMinus = Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> Double
nll Distribution
Gaussian Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
thetaMinus
Double -> IO Double
forall a. a -> IO a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Double -> IO Double) -> Double -> IO Double
forall a b. (a -> b) -> a -> b
$ (Double
fPlus Double -> Double -> Double
forall a. Num a => a -> a -> a
+ Double
fMinus Double -> Double -> Double
forall a. Num a => a -> a -> a
- Double
2Double -> Double -> Double
forall a. Num a => a -> a -> a
*Double
f)Double -> Double -> Double
forall a. Fractional a => a -> a -> a
/(Double
epsDouble -> Double -> Double
forall a. Num a => a -> a -> a
*Double
eps)
fisherNLL Distribution
dist Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
theta = Comp -> Sz Int -> (Int -> Double) -> SRVector
forall r ix e.
Load r ix e =>
Comp -> Sz ix -> (ix -> e) -> Array r ix e
makeArray Comp
cmp (Int -> Sz Int
forall ix. Index ix => ix -> Sz ix
Sz Int
p) Int -> Double
build
where
build :: Int -> Double
build Int
ix = let dtdix :: Fix SRTree
dtdix = Int -> Fix SRTree -> Fix SRTree
deriveByParam Int
ix Fix SRTree
t'
d2tdix2 :: Fix SRTree
d2tdix2 = Int -> Fix SRTree -> Fix SRTree
deriveByParam Int
ix Fix SRTree
dtdix
f' :: SRVector
f' = Fix SRTree -> SRVector
eval Fix SRTree
dtdix
f'' :: SRVector
f'' = Fix SRTree -> SRVector
eval Fix SRTree
d2tdix2
in SRVector -> Double
forall ix r e. (Index ix, Source r e, Num e) => Array r ix e -> e
M.sum (SRVector -> Double) -> SRVector -> Double
forall a b. (a -> b) -> a -> b
$ SRVector
phi' SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
* SRVector
f'SRVector -> Integer -> SRVector
forall a b. (Num a, Integral b) => a -> b -> a
^Integer
2 SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
- SRVector
res SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
* SRVector
f''
cmp :: Comp
cmp = SRMatrix -> Comp
forall r ix e. Strategy r => Array r ix e -> Comp
forall ix e. Array S ix e -> Comp
getComp SRMatrix
xss
(Sz Int
m) = PVector -> Sz Int
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size PVector
ys
(Sz Int
p) = PVector -> Sz Int
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size PVector
theta
t' :: Fix SRTree
t' = (Fix SRTree, [Double]) -> Fix SRTree
forall a b. (a, b) -> a
fst ((Fix SRTree, [Double]) -> Fix SRTree)
-> (Fix SRTree, [Double]) -> Fix SRTree
forall a b. (a -> b) -> a -> b
$ Fix SRTree -> (Fix SRTree, [Double])
floatConstsToParam Fix SRTree
tree
eval :: Fix SRTree -> SRVector
eval = SRMatrix -> PVector -> Fix SRTree -> SRVector
evalTree SRMatrix
xss PVector
theta
yhat :: SRVector
yhat = Fix SRTree -> SRVector
eval Fix SRTree
t'
res :: SRVector
res = PVector -> SRVector
forall ix r e.
(Index ix, Source r e) =>
Array r ix e -> Array D ix e
delay PVector
ys SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
- SRVector
phi
yErr :: PVector
yErr = case Maybe PVector
mYerr of
Maybe PVector
Nothing -> Comp -> Sz Int -> Double -> PVector
forall r ix e. Load r ix e => Comp -> Sz ix -> e -> Array r ix e
M.replicate (SRMatrix -> Comp
forall r ix e. Strategy r => Array r ix e -> Comp
forall ix e. Array S ix e -> Comp
getComp SRMatrix
xss) (Int -> Sz Int
forall ix. Index ix => ix -> Sz ix
Sz Int
m) Double
est
Just PVector
e -> PVector
e
est :: Double
est = Int -> Double
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int
m Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
p)
(SRVector
phi, SRVector
phi') = case Distribution
dist of
Distribution
MSE -> (SRVector
yhat, Comp -> Sz Int -> Double -> SRVector
forall r ix e. Load r ix e => Comp -> Sz ix -> e -> Array r ix e
M.replicate Comp
compMode (Int -> Sz Int
forall ix. Index ix => ix -> Sz ix
Sz Int
m) Double
1)
Distribution
Gaussian -> (SRVector
yhat, Comp -> Sz Int -> Double -> SRVector
forall r ix e. Load r ix e => Comp -> Sz ix -> e -> Array r ix e
M.replicate Comp
compMode (Int -> Sz Int
forall ix. Index ix => ix -> Sz ix
Sz Int
m) Double
1)
Distribution
Bernoulli -> (SRVector -> SRVector
forall a. Floating a => a -> a
logistic SRVector
yhat, SRVector
phiSRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*(Comp -> Sz Int -> Double -> SRVector
forall r ix e. Load r ix e => Comp -> Sz ix -> e -> Array r ix e
M.replicate Comp
compMode (Int -> Sz Int
forall ix. Index ix => ix -> Sz ix
Sz Int
m) Double
1 SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
- SRVector
phi))
Distribution
Poisson -> (SRVector -> SRVector
forall a. Floating a => a -> a
exp SRVector
yhat, SRVector
phi)
hessianNLL :: Distribution -> Maybe PVector -> SRMatrix -> PVector -> Fix SRTree -> PVector -> SRMatrix
hessianNLL :: Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> SRMatrix
hessianNLL Distribution
ROXY Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
theta = SRMatrix
forall a. HasCallStack => a
undefined
hessianNLL Distribution
dist Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
theta = Comp -> Sz Ix2 -> (Ix2 -> Double) -> SRMatrix
forall r ix e.
Load r ix e =>
Comp -> Sz ix -> (ix -> e) -> Array r ix e
makeArray Comp
cmp (Ix2 -> Sz Ix2
forall ix. Index ix => ix -> Sz ix
Sz (Int
p Int -> Int -> Ix2
:. Int
p)) Ix2 -> Double
build
where
build :: Ix2 -> Double
build (Int
ix :. Int
iy) = let dtdix :: Fix SRTree
dtdix = Int -> Fix SRTree -> Fix SRTree
deriveByParam Int
ix Fix SRTree
t'
dtdiy :: Fix SRTree
dtdiy = Int -> Fix SRTree -> Fix SRTree
deriveByParam Int
iy Fix SRTree
t'
d2tdixy :: Fix SRTree
d2tdixy = Int -> Fix SRTree -> Fix SRTree
deriveByParam Int
iy Fix SRTree
dtdix
fx :: SRVector
fx = Fix SRTree -> SRVector
eval Fix SRTree
dtdix
fy :: SRVector
fy = Fix SRTree -> SRVector
eval Fix SRTree
dtdiy
fxy :: SRVector
fxy = Fix SRTree -> SRVector
eval Fix SRTree
d2tdixy
in case Distribution
dist of
Distribution
Gaussian -> SRVector -> Double
forall ix r e. (Index ix, Source r e, Num e) => Array r ix e -> e
M.sum (SRVector -> Double)
-> (SRVector -> SRVector) -> SRVector -> Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SRVector -> SRVector -> SRVector
forall a. Fractional a => a -> a -> a
/PVector -> SRVector
forall ix r e.
(Index ix, Source r e) =>
Array r ix e -> Array D ix e
delay PVector
yErr) (SRVector -> Double) -> SRVector -> Double
forall a b. (a -> b) -> a -> b
$ SRVector
phi' SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
* SRVector
fx SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
* SRVector
fy SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
- SRVector
res SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
* SRVector
fxy
Distribution
_ -> SRVector -> Double
forall ix r e. (Index ix, Source r e, Num e) => Array r ix e -> e
M.sum (SRVector -> Double) -> SRVector -> Double
forall a b. (a -> b) -> a -> b
$ SRVector
phi' SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
* SRVector
fx SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
* SRVector
fy SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
- SRVector
res SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
* SRVector
fxy
cmp :: Comp
cmp = SRMatrix -> Comp
forall r ix e. Strategy r => Array r ix e -> Comp
forall ix e. Array S ix e -> Comp
getComp SRMatrix
xss
(Sz Int
m) = PVector -> Sz Int
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size PVector
ys
(Sz Int
p) = PVector -> Sz Int
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size PVector
theta
t' :: Fix SRTree
t' = Fix SRTree
tree
eval :: Fix SRTree -> SRVector
eval = SRMatrix -> PVector -> Fix SRTree -> SRVector
evalTree SRMatrix
xss PVector
theta
yErr :: PVector
yErr = case Maybe PVector
mYerr of
Maybe PVector
Nothing -> Comp -> Sz Int -> Double -> PVector
forall r ix e. Load r ix e => Comp -> Sz ix -> e -> Array r ix e
M.replicate Comp
compMode (Int -> Sz Int
forall ix. Index ix => ix -> Sz ix
Sz Int
m) Double
est
Just PVector
e -> PVector
e
est :: Double
est = Int -> Double
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int
m Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
p)
yhat :: SRVector
yhat = Fix SRTree -> SRVector
eval Fix SRTree
t'
res :: SRVector
res = PVector -> SRVector
forall ix r e.
(Index ix, Source r e) =>
Array r ix e -> Array D ix e
delay PVector
ys SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
- SRVector
phi
(SRVector
phi, SRVector
phi') = case Distribution
dist of
Distribution
MSE -> (SRVector
yhat, Comp -> Sz Int -> Double -> SRVector
forall r ix e. Load r ix e => Comp -> Sz ix -> e -> Array r ix e
M.replicate Comp
cmp (Int -> Sz Int
forall ix. Index ix => ix -> Sz ix
Sz Int
m) Double
1)
Distribution
Gaussian -> (SRVector
yhat, Comp -> Sz Int -> Double -> SRVector
forall r ix e. Load r ix e => Comp -> Sz ix -> e -> Array r ix e
M.replicate Comp
cmp (Int -> Sz Int
forall ix. Index ix => ix -> Sz ix
Sz Int
m) Double
1)
Distribution
Bernoulli -> (SRVector -> SRVector
forall a. Floating a => a -> a
logistic SRVector
yhat, SRVector
phiSRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*(Comp -> Sz Int -> Double -> SRVector
forall r ix e. Load r ix e => Comp -> Sz ix -> e -> Array r ix e
M.replicate Comp
cmp (Int -> Sz Int
forall ix. Index ix => ix -> Sz ix
Sz Int
m) Double
1 SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
- SRVector
phi))
Distribution
Poisson -> (SRVector -> SRVector
forall a. Floating a => a -> a
exp SRVector
yhat, SRVector
phi)
tree2arr :: Fix SRTree -> IntMap.IntMap (Int, Int, Int, Double)
tree2arr :: Fix SRTree -> IntMap (Int, Int, Int, Double)
tree2arr Fix SRTree
tree = [(Int, (Int, Int, Int, Double))] -> IntMap (Int, Int, Int, Double)
forall a. [(Int, a)] -> IntMap a
IntMap.fromList [(Int, (Int, Int, Int, Double))]
listTree
where
height :: Fix SRTree -> Integer
height = (SRTree Integer -> Integer) -> Fix SRTree -> Integer
forall (f :: * -> *) a. Functor f => (f a -> a) -> Fix f -> a
cata SRTree Integer -> Integer
forall {a}. (Num a, Ord a) => SRTree a -> a
alg
where
alg :: SRTree a -> a
alg (Var Int
ix) = a
1
alg (Const Double
x) = a
1
alg (Param Int
ix) = a
1
alg (Uni Function
_ a
t) = a
1 a -> a -> a
forall a. Num a => a -> a -> a
+ a
t
alg (Bin Op
_ a
l a
r) = a
1 a -> a -> a
forall a. Num a => a -> a -> a
+ a -> a -> a
forall a. Ord a => a -> a -> a
max a
l a
r
listTree :: [(Int, (Int, Int, Int, Double))]
listTree = (forall x. SRTree x -> Int -> SRTree (x, Int))
-> (SRTree [(Int, (Int, Int, Int, Double))]
-> Int -> [(Int, (Int, Int, Int, Double))])
-> Fix SRTree
-> Int
-> [(Int, (Int, Int, Int, Double))]
forall (f :: * -> *) p a.
Functor f =>
(forall x. f x -> p -> f (x, p))
-> (f a -> p -> a) -> Fix f -> p -> a
accu SRTree x -> Int -> SRTree (x, Int)
forall x. SRTree x -> Int -> SRTree (x, Int)
forall {b} {a}. Num b => SRTree a -> b -> SRTree (a, b)
indexer SRTree [(Int, (Int, Int, Int, Double))]
-> Int -> [(Int, (Int, Int, Int, Double))]
forall {a} {a}.
Num a =>
SRTree [(a, (a, Int, Int, Double))]
-> a -> [(a, (a, Int, Int, Double))]
convert Fix SRTree
tree Int
0
indexer :: SRTree a -> b -> SRTree (a, b)
indexer (Var Int
ix) b
iy = Int -> SRTree (a, b)
forall val. Int -> SRTree val
Var Int
ix
indexer (Const Double
x) b
iy = Double -> SRTree (a, b)
forall val. Double -> SRTree val
Const Double
x
indexer (Param Int
ix) b
iy = Int -> SRTree (a, b)
forall val. Int -> SRTree val
Param Int
ix
indexer (Bin Op
op a
l a
r) b
iy = Op -> (a, b) -> (a, b) -> SRTree (a, b)
forall val. Op -> val -> val -> SRTree val
Bin Op
op (a
l, b
2b -> b -> b
forall a. Num a => a -> a -> a
*b
iyb -> b -> b
forall a. Num a => a -> a -> a
+b
1) (a
r, b
2b -> b -> b
forall a. Num a => a -> a -> a
*b
iyb -> b -> b
forall a. Num a => a -> a -> a
+b
2)
indexer (Uni Function
f a
t) b
iy = Function -> (a, b) -> SRTree (a, b)
forall val. Function -> val -> SRTree val
Uni Function
f (a
t, b
2b -> b -> b
forall a. Num a => a -> a -> a
*b
iyb -> b -> b
forall a. Num a => a -> a -> a
+b
1)
convert :: SRTree [(a, (a, Int, Int, Double))]
-> a -> [(a, (a, Int, Int, Double))]
convert (Var Int
ix) a
iy = [(a
iy, (a
0, Int
0, Int
ix, -Double
1))]
convert (Const Double
x) a
iy = [(a
iy, (a
0, Int
2, -Int
1, Double
x))]
convert (Param Int
ix) a
iy = [(a
iy, (a
0, Int
1, Int
ix, -Double
1))]
convert (Uni Function
f [(a, (a, Int, Int, Double))]
t) a
iy = (a
iy, (a
1, Function -> Int
forall a. Enum a => a -> Int
fromEnum Function
f, -Int
1, -Double
1)) (a, (a, Int, Int, Double))
-> [(a, (a, Int, Int, Double))] -> [(a, (a, Int, Int, Double))]
forall a. a -> [a] -> [a]
: [(a, (a, Int, Int, Double))]
t
convert (Bin Op
op [(a, (a, Int, Int, Double))]
l [(a, (a, Int, Int, Double))]
r) a
iy = (a
iy, (a
2, Op -> Int
forall a. Enum a => a -> Int
fromEnum Op
op, -Int
1, -Double
1)) (a, (a, Int, Int, Double))
-> [(a, (a, Int, Int, Double))] -> [(a, (a, Int, Int, Double))]
forall a. a -> [a] -> [a]
: ([(a, (a, Int, Int, Double))]
l [(a, (a, Int, Int, Double))]
-> [(a, (a, Int, Int, Double))] -> [(a, (a, Int, Int, Double))]
forall a. Semigroup a => a -> a -> a
<> [(a, (a, Int, Int, Double))]
r)
{-# INLINE tree2arr #-}