{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TypeApplications #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.SRTree.Likelihoods 
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :  ConstraintKinds
--
-- Functions to calculate different likelihood functions, their gradient, and Hessian matrices.
--
-----------------------------------------------------------------------------
module Algorithm.SRTree.Likelihoods
  ( Distribution (..)
  , PVector
  , SRMatrix
  , sse
  , mse
  , rmse
  , r2
  , nll
  , predict
  , buildNLL
  , buildNLLEGraph
  , gradNLL
  , gradNLLArr
  , gradNLLGraph
  , gradNLLEGraph
  , fisherNLL
  , getSErr
  , hessianNLL
  , tree2arr
  )
    where

import Algorithm.SRTree.AD ( reverseModeArr, reverseModeGraph, reverseModeEGraph )
import Data.Massiv.Array hiding (all, map, read, replicate, tail, take, zip)
import qualified Data.Massiv.Array as M
import qualified Data.Massiv.Array.Mutable as Mut
import Data.Maybe (fromMaybe)
import Data.SRTree
import Data.SRTree.Recursion ( cata, accu )
import Data.SRTree.Derivative (deriveByParam, deriveByVar, derivative)
import Data.SRTree.Eval
import qualified Data.IntMap.Strict as IntMap
import qualified Data.Vector.Storable as VS
import GHC.IO (unsafePerformIO)
import Data.Maybe

import Debug.Trace
import Data.SRTree.Print
import Algorithm.EqSat.Egraph
import Algorithm.EqSat.Simplify
import Algorithm.EqSat.Build
import Control.Monad.State.Strict
import Control.Monad.Identity

import Data.SRTree.Print

-- | Supported distributions for the negative log-likelihood ('nll').
--
-- 'MSE' is not a distribution; it refers to the mean squared error.
-- 'HGaussian' is a Gaussian with heteroscedasticity, for which the measured
-- error of the target variable must be provided.
data Distribution
  = MSE        -- ^ mean squared error (not a distribution)
  | Gaussian   -- ^ Gaussian distribution
  | HGaussian  -- ^ Gaussian with heteroscedasticity; needs the measured target error
  | Bernoulli  -- ^ Bernoulli; targets must be either 0 or 1
  | Poisson    -- ^ Poisson; targets must be non-negative
  | ROXY       -- ^ ROXY; needs x,y-errors and 3 additional parameters in theta
  deriving (Show, Read, Enum, Bounded, Eq)

-- | Sum-of-square errors (sum of squared residuals) of the model @tree@,
-- evaluated with parameters @theta@ on inputs @xss@, against targets @ys@.
sse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
sse xss ys tree theta = M.sum $ (delay ys - yhat) ^ (2 :: Int)
  where
    yhat = evalTree xss theta tree

-- | Sum of squared residuals, each divided by the matching entry of @yErr@.
--
-- NOTE(review): @yErr@ divides the squared residual directly, so it
-- presumably holds a per-sample variance rather than a standard deviation;
-- confirm with callers.
sseError :: SRMatrix -> PVector -> PVector -> Fix SRTree -> PVector -> Double
sseError xss ys yErr tree theta =
    M.sum $ (delay ys - yhat) ^ (2 :: Int) / delay yErr
  where
    yhat = evalTree xss theta tree

-- | Total sum of squares: squared deviations of the targets @ys@ from their
-- mean. The inputs, tree and parameters are not used; they are accepted
-- only so the signature matches 'sse'.
sseTot :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
sseTot _ ys _ _ = M.sum $ M.map (subtract ym) ys ^ (2 :: Int)
  where
    (Sz m) = M.size ys
    ym     = M.sum ys / fromIntegral m
        
-- | Mean squared error: 'sse' divided by the number of targets.
mse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
mse xss ys tree theta = sse xss ys tree theta / fromIntegral m
  where
    (Sz m) = M.size ys

-- | Root of the mean squared error; takes the parameter vector last so it
-- can be partially applied like 'mse'.
rmse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
rmse xss ys tree = sqrt . mse xss ys tree

-- | Coefficient of determination, @1 - sse / sseTot@.
--
-- When every target equals the mean, 'sseTot' is 0 and the division yields
-- a non-finite value (Infinity or NaN).
r2 :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
r2 xss ys tree theta = 1 - sse xss ys tree theta / sseTot xss ys tree theta

-- | Logistic (sigmoid) function, @1 / (1 + exp (-x))@.
logistic :: Floating a => a -> a
logistic x = 1 / (1 + exp (negate x))
{-# INLINE logistic #-}

-- | Standard error to use for a given distribution.
--
-- For 'Gaussian', returns the provided value when it is @Just@, otherwise
-- the estimate @est@. For every other distribution it is the constant 1.
getSErr :: Num a => Distribution -> a -> Maybe a -> a
getSErr Gaussian est = fromMaybe est
getSErr _        _   = const 1
{-# INLINE getSErr #-}

-- | Negation of the sum of the values in a vector.
negSum :: PVector -> Double
negSum = negate . M.sum
{-# INLINE negSum #-}

-- | Negative log-likelihood of the model @tree@, evaluated with parameters
-- @theta@ on inputs @xss@, against targets @ys@, under the given
-- 'Distribution'.
--
-- The @Maybe PVector@ argument carries measured errors: it is required by
-- 'HGaussian' and 'ROXY' and ignored by the other distributions.
-- Invalid inputs are reported with 'error'.
nll :: Distribution
    -> Maybe PVector
    -> SRMatrix
    -> PVector
    -> Fix SRTree
    -> PVector
    -> Double

-- Mean squared error (not a distribution).
nll MSE _ xss ys t theta = mse xss ys t theta

-- Gaussian distribution. The variance is currently estimated from the
-- residuals (via 'mse') instead of being read from theta.
-- NOTE(review): the guard fires when theta has exactly one value more than
-- the tree's parameters, which is the case its message asks for, so it looks
-- inverted. Also @s@ is the square root of the MSE but appears where the
-- Gaussian NLL expects a variance. Both are kept as-is because the matching
-- gradient code is not visible here; confirm before changing.
nll Gaussian _ xss ys t theta
  | nParams == p' - 1 =
      error "For Gaussian distribution theta must contain the variance as its last value."
  | otherwise = 0.5 * (sse xss ys t theta / s + m * log (2 * pi * s))
  where
    s       = sqrt $ mse xss ys t theta
    (Sz m') = M.size ys
    (Sz p') = M.size theta
    nParams = countParamsUniq t
    m       = fromIntegral m' :: Double

-- Gaussian with heteroscedasticity; the measured error of the target
-- variable is mandatory. Each squared residual is divided by its yErr entry
-- and log (2 * pi * yErr) is added per sample.
nll HGaussian mYerr xss ys t theta =
  case mYerr of
    Nothing   -> error "For HGaussian, you must provide the measured error for the target variable."
    Just yErr -> 0.5 * ( sseError xss ys yErr t theta
                       + M.sum (M.map (log . (2*) . (pi*)) yErr) )

-- Bernoulli distribution. With phi = 1 / (1 + exp (-f(x; theta))) the
-- per-sample term is -(y log phi + (1-y) log (1 - phi)), which simplifies
-- to (1 - y) * f + log (1 + exp (-f)). Targets must be 0 or 1.
nll Bernoulli _ xss ys tree theta
  | notValid ys = error "For Bernoulli distribution the output must be either 0 or 1."
  | otherwise   = M.sum $ M.map (1-) (delay ys) * yhat
                        + log (M.map (1+) $ exp (M.map negate yhat))
  where
    yhat = evalTree xss theta tree

    notValid :: PVector -> Bool
    notValid = M.any (\x -> x /= 0 && x /= 1)

-- Poisson distribution with rate exp (f(x; theta)). Targets must be
-- non-negative. The y * log y term does not depend on theta.
nll Poisson _ xss ys tree theta
  | notValid ys = error "For Poisson distribution the output must be non-negative."
  | otherwise   = negate . M.sum $ ys' * yhat - yLogY - exp yhat
  where
    ys'  = delay ys
    yhat = evalTree xss theta tree

    -- y * log y using the limit 0 * log 0 = 0; computing it as
    -- ys' * log ys' made any zero count turn the whole result into NaN.
    yLogY = M.map (\y -> if y == 0 then 0 else y * log y) ys

    notValid :: PVector -> Bool
    notValid = M.any (< 0)

-- ROXY likelihood. Requires the x,y-errors, three extra values in theta
-- stored right after the tree's parameters (sig, mu_gauss, w_gauss), and a
-- data matrix whose column 0 is the input variable and whose columns 1-4
-- hold the cached logX, logY, logXErr and logYErr data. A NaN result is
-- reported as +Infinity.
-- NOTE(review): n == 1 passes the guard, but columns 1-4 are still read
-- below, so evaluation then fails with an index error; the single-variable
-- path looks unfinished.
nll ROXY mYerr xss _ tree theta
  | isNothing mYerr    = error "Can't calculate ROXY nll without x,y-errors."
  | p < num_params + 3 = error "We need 3 additional parameters for ROXY."
  | n /= 1 && n /= 5   = error "For ROXY dataset must contain a single variable, or 1 variable + 4 cached data."
  | isNaN negLL        = 1 / 0
  | otherwise          = negLL
  where
    (Sz p)     = M.size theta
    (Sz2 m n)  = M.size xss
    num_params = countParamsUniq tree

    x0, logX, logY, logXErr, logYErr :: Array D Ix1 Double
    x0      = xss <! 0
    logX    = xss <! 1
    logY    = xss <! 2
    logXErr = xss <! 3
    logYErr = xss <! 4

    one, zero :: Array D Ix1 Double
    one  = M.replicate compMode (Sz m) 1
    zero = M.replicate compMode (Sz m) 0

    sig      = theta M.! num_params
    mu_gauss = theta M.! (num_params + 1)
    w_gauss  = theta M.! (num_params + 2)

    -- Derivative of a binary operator w.r.t. the input variable, given the
    -- values (l, r) and derivatives (dl, dr) of both operands.
    applyDer :: Op -> Array D Ix1 Double -> Array D Ix1 Double -> Array D Ix1 Double -> Array D Ix1 Double -> Array D Ix1 Double
    applyDer Add      _ dl _ dr = dl + dr
    applyDer Sub      _ dl _ dr = dl - dr
    applyDer Mul      l dl r dr = l * dr + r * dl
    applyDer Div      l dl r dr = (dl * r - dr * l) / (r ^ (2 :: Int))
    applyDer Power    l dl r dr = l ** (r .- 1) * (r * dl + l * log l * dr)
    applyDer PowerAbs l dl r dr = (abs l ** r) * (dr * log (abs l) + r * dl / l)
    applyDer AQ       l dl r dr = ((1 +. r * r) * dl - l * r * dr) / M.map (** 1.5) (1 +. r * r)

    -- Forward pass computing the model output and its derivative w.r.t. the
    -- input variable. Every Var node is read from column 0.
    (yhat, grad) = cata alg tree
      where
        alg :: SRTree (Array D Ix1 Double, Array D Ix1 Double) -> (Array D Ix1 Double, Array D Ix1 Double)
        alg (Var _)         = (x0, one)
        alg (Param ix)      = (M.replicate compMode (Sz m) (theta M.! ix), zero)
        alg (Const c)       = (M.replicate compMode (Sz m) c, zero)
        alg (Uni g (v, dv)) = (M.map (evalFun g) v, M.map (derivative g) v * dv)
        alg (Bin op (vl, dl) (vr, dr)) =
          (M.zipWith (evalOp op) vl vr, applyDer op vl dl vr dr)

    -- Model output in log10 space and its derivative w.r.t. log10 x
    -- (x * f'(x) / f(x)).
    f      = M.map (logBase 10) (abs yhat)
    fprime = grad / (log 10 *. yhat) * x0 .* log 10

    w_gauss2 = w_gauss ^ (2 :: Int)
    s2       = delay $ logYErr .+ sig ^ (2 :: Int)
    den      = fprime ^ (2 :: Int) .* w_gauss2 * logXErr + s2 * (w_gauss2 +. logXErr)

    neglogP = log (2 * pi)
              +. log den
              + ( w_gauss2 *. (f - logY) * (f - logY)
                + logXErr * (fprime * (mu_gauss -. logX) + f - logY) ^ (2 :: Int)
                + s2 * (logX .- mu_gauss) ^ (2 :: Int)
                ) / den

    negLL = 0.5 * M.sum neglogP

-- | Symbolic negative log-likelihood of @tree@ under the given 'Distribution'.
--
-- WARNING: pass tree with parameters (numeric constants already turned into
-- 'param' nodes). The extra parameters this function introduces ('Gaussian'
-- sigma, and the 'ROXY' nuisance parameters) are indexed starting at
-- @countParamsUniq tree@ — presumably the first index not used by @tree@.
--
-- TODO: handle error similar to ROXY
--
-- Variables read by the returned expression:
--
--   * @var (-1)@: the observed target;
--   * @var (-2)@: the per-sample error ('HGaussian' only);
--   * 'ROXY': @var 0@ (linear-scale input, used only in the chain rule of the
--     log-log derivative), @var 1@ .. @var 4@ (log x, log y, log-x error,
--     log-y error).
--
-- @m@ is a numeric constant embedded in the expression. In the 'MSE' case the
-- squared residual is divided by it; in the 'HGaussian' case the log term is
-- multiplied by it. The callers in this module pass the number of data rows.
--
-- NOTE(review): the 'Poisson' term @y * log y@ evaluates to NaN for @y = 0@
-- with IEEE doubles (0 * -Infinity). Confirm whether zero counts are expected.
buildNLL :: Distribution -> Double -> Fix SRTree -> Fix SRTree
buildNLL dist m tree =
  case dist of
    MSE       -> ((tree - target) ** 2) / constv m
    Gaussian  -> square (tree - target) / square sigma + log (square sigma)
    HGaussian -> (tree - target) ** 2 / yErr + constv m * log (2 * pi * yErr)
    Poisson   -> target * log target + exp tree - target * tree
    Bernoulli -> log (1 + exp (negate tree)) + (1 - target) * tree
    ROXY      -> roxy
  where
    target, yErr :: Fix SRTree
    target = var (-1)
    yErr   = var (-2)

    -- Index of the first parameter introduced by this function.
    nParams :: Int
    nParams = countParamsUniq tree

    -- Gaussian noise scale; ROXY reuses the same parameter as its sigma.
    sigma :: Fix SRTree
    sigma = param nParams

    square :: Fix SRTree -> Fix SRTree
    square = Fix . Uni Square

    ---------------------------------------------------------------- ROXY
    ln10 :: Fix SRTree
    ln10 = log 10

    -- Model evaluated in log10 space and its derivative d log10 f / d log10 x.
    roxyF, roxyFPrime :: Fix SRTree
    roxyF      = log (abs tree) / ln10
    roxyFPrime = deriveByVar 0 tree / (ln10 * tree) * var 0 * ln10

    logX, logY, logXErr, logYErr :: Fix SRTree
    logX    = var 1
    logY    = var 2
    logXErr = var 3
    logYErr = var 4

    muGauss, wGauss2, s2, den :: Fix SRTree
    muGauss = param (nParams + 1)
    wGauss2 = param (nParams + 2) ** 2
    s2      = logYErr + sigma ** 2
    den     = roxyFPrime ** 2 * wGauss2 * logXErr + s2 * (wGauss2 + logXErr)

    -- Twice the per-sample ROXY negative log-likelihood (no 0.5 factor).
    roxy :: Fix SRTree
    roxy = log (2 * pi)
         + log den
         + ( wGauss2 * (roxyF - logY) * (roxyF - logY)
           + logXErr * (roxyFPrime * (muGauss - logX) + roxyF - logY) ** 2
           + s2 * (logX - muGauss) ** 2
           ) / den

-- | Adds the nodes of the negative log-likelihood of the expression rooted at
-- e-class @root@ to @egraph@.
--
-- Returns the e-class id of the NLL root (the last node added) together with
-- the updated e-graph. @var (-1)@ is the target and @var (-2)@ the per-sample
-- error ('HGaussian'). @m@ is inserted as a constant: it divides the squared
-- residual for 'MSE' and multiplies the log term for 'HGaussian'.
--
-- Each 'add' mutates the e-graph, so the order of the statements below
-- matters and mirrors the expression shapes built by 'buildNLL'.
--
-- Calls 'error' for 'ROXY', which is not supported here.
buildNLLEGraph :: Distribution -> Double -> EGraph -> Int -> (Int, EGraph)
buildNLLEGraph dist m egraph root = runState (nllNodes dist) egraph
  where
    node :: ENode -> EGraphST Identity EClassId
    node = add myCost

    -- Index of the extra Gaussian noise parameter; evaluated only when needed.
    nParams :: Int
    nParams = countParamsUniqEg egraph root

    nllNodes :: Distribution -> EGraphST Identity EClassId
    -- (root - y) ^ 2 / m
    nllNodes MSE = do
      y     <- node (Var (-1))
      two   <- node (Const 2)
      mc    <- node (Const m)
      resid <- node (Bin Sub root y)
      sq    <- node (Bin Power resid two)
      node (Bin Div sq mc)
    -- (root - y)^2 / sigma^2 + log sigma^2
    nllNodes Gaussian = do
      y      <- node (Var (-1))
      sigma  <- node (Param nParams)
      sigma2 <- node (Uni Square sigma)
      logS2  <- node (Uni Log sigma2)
      resid  <- node (Bin Sub root y)
      resid2 <- node (Uni Square resid)
      scaled <- node (Bin Div resid2 sigma2)
      node (Bin Add scaled logS2)
    -- (root - y)^2 / yerr + m * log (2 pi yerr)
    -- NOTE(review): the log term is scaled by m for every sample; confirm.
    nllNodes HGaussian = do
      y      <- node (Var (-1))
      yErr   <- node (Var (-2))
      twoPi  <- node (Const (2 * pi))
      mc     <- node (Const m)
      resid  <- node (Bin Sub root y)
      resid2 <- node (Uni Square resid)
      scaled <- node (Bin Div resid2 yErr)
      w      <- node (Bin Mul twoPi yErr)
      logW   <- node (Uni Log w)
      penal  <- node (Bin Mul mc logW)
      node (Bin Add scaled penal)
    -- y log y + exp root - y root
    -- NOTE(review): y log y is NaN for y = 0 with IEEE doubles; confirm.
    nllNodes Poisson = do
      y     <- node (Var (-1))
      logY  <- node (Uni Log y)
      yLogY <- node (Bin Mul y logY)
      expF  <- node (Uni Exp root)
      total <- node (Bin Add yLogY expF)
      yF    <- node (Bin Mul y root)
      node (Bin Sub total yF)
    -- log (1 + exp (-root)) + (1 - y) root
    nllNodes Bernoulli = do
      y         <- node (Var (-1))
      one       <- node (Const 1)
      minusOne  <- node (Const (-1))
      negF      <- node (Bin Mul minusOne root)
      expNegF   <- node (Uni Exp negF)
      onePlus   <- node (Bin Add one expNegF)
      softplus  <- node (Uni Log onePlus)
      oneMinusY <- node (Bin Sub one y)
      linear    <- node (Bin Mul oneMinusY root)
      node (Bin Add softplus linear)
    nllNodes ROXY = error "ROXY not supported with cache"

-- | Prediction for different distributions.
--
-- Evaluates @tree@ with parameters @theta@ on the data matrix @xss@ and
-- applies the inverse link function of the distribution:
--
--   * 'Bernoulli': 'logistic';
--   * 'Poisson': 'exp';
--   * 'MSE', 'Gaussian', 'HGaussian', 'ROXY': identity.
--
-- 'HGaussian' was previously missing from the match and crashed with a
-- non-exhaustive pattern error. It uses the identity link, as 'Gaussian'
-- does: 'buildNLL' compares the raw tree output against the target for both.
predict :: Distribution -> Fix SRTree -> PVector -> SRMatrix -> SRVector
predict dist tree theta xss = invLink (evalTree xss theta tree)
  where
    invLink :: SRVector -> SRVector
    invLink = case dist of
      Bernoulli -> logistic
      Poisson   -> exp
      _         -> id

-- | Objective value and gradient for parameter fitting, estimated with
-- forward finite differences.
--
-- Returns @(f0, grad)@:
--
--   * @f0@ is the mean squared error
--     @sum ((predict MSE tree theta xss - ys) ^ 2) / m@, where @m@ is the
--     number of rows of @xss@;
--   * @grad@ holds one forward-difference estimate per entry of @theta@.
--
-- The step for parameter @i@ is @1e-8 * max 1 |theta_i|@. Using a relative
-- step keeps the perturbation from vanishing in rounding for large parameter
-- values, and the difference quotient divides by the step that was actually
-- representable.
--
-- NOTE(review): the 'Distribution' and the optional error vector are
-- currently ignored; both outputs are MSE-based. The original source carried
-- a commented-out call to 'gradNLLArr' (reverse mode) plus unused locals to
-- support it. Confirm whether that path should be restored.
gradNLL :: Distribution -> Maybe PVector -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (Double, SRVector)
gradNLL _dist _mYerr xss ys tree theta = (f0, delay grad)
  where
    Sz2 m _ = M.size xss

    -- Mean squared error of the model at parameter vector @t@.
    mseAt :: PVector -> Double
    mseAt t = M.sum (M.map (^ (2 :: Int)) (predict MSE tree t xss - delay ys)) / fromIntegral m

    f0 :: Double
    f0 = mseAt theta

    eps :: Double
    eps = 1e-8

    thetaList :: [Double]
    thetaList = M.toList theta

    -- Forward-difference estimate of the partial derivative w.r.t. the
    -- parameter at position @ix@, whose current value is @v@.
    partial :: Int -> Double -> Double
    partial ix v = (mseAt bumped - f0) / (v' - v)
      where
        v'     = v + eps * max 1 (abs v)
        bumped = M.fromList M.Seq
                   [ if iy == ix then v' else w | (iy, w) <- Prelude.zip [0 ..] thetaList ]

    grad :: PVector
    grad = M.fromList M.Seq [ partial ix v | (ix, v) <- Prelude.zip [0 ..] thetaList ]



-- | Currently the identity function.
--
-- The name refers to replacing NaN / infinite values with 0. That guard is
-- disabled and kept only as the trailing comment below, so every value
-- (including NaN and Infinity) passes through unchanged.
nanTo0 :: a -> a
nanTo0 = id -- disabled: if isNaN x || isInfinite x then 0 else x
{-# INLINE nanTo0 #-}

-- | Negative log-likelihood and its gradient, computed by reverse-mode
-- automatic differentiation ('reverseModeArr') over the array representation
-- of the NLL expression (@tree@) with parameter-index map @j2ix@.
--
-- The first component is the sum of the first output of 'reverseModeArr'
-- (presumably the per-sample NLL values — confirm in "Algorithm.SRTree.AD").
-- The second component is the gradient, passed through 'nanTo0'.
--
-- Validation (via 'error'):
--
--   * 'Bernoulli' requires every target to be 0 or 1;
--   * 'Poisson' requires every target to be non-negative.
--
-- For 'ROXY' both outputs are halved. This matches the vector ROXY NLL
-- (@0.5 * sum neglogP@), whereas 'buildNLL' builds the expression without
-- the 0.5 factor.
--
-- 'HGaussian' was previously missing from the match and crashed with a
-- non-exhaustive pattern error. It now shares the default path, like
-- 'MSE' and 'Gaussian'.
gradNLLArr :: Distribution
           -> SRMatrix
           -> PVector
           -> Maybe PVector
           -> [(Int, (Int, Int, Int, Double))]
           -> IntMap.IntMap Int
           -> VS.Vector Double
           -> (Double, SRVector)
gradNLLArr dist xss ys mYerr tree j2ix theta =
  -- A failing guard falls through to the next alternative, so valid
  -- Bernoulli/Poisson targets end up in the default branch.
  case dist of
    Bernoulli
      | M.any (\y -> y /= 0 && y /= 1) ys ->
          error "For Bernoulli distribution the output must be either 0 or 1."
    Poisson
      | M.any (< 0) ys ->
          error "For Poisson distribution the output must be non-negative."
    ROXY -> (M.sum yhat * 0.5, M.map (* 0.5) grad')
    _    -> (M.sum yhat, grad')
  where
    (yhat, grad) = reverseModeArr xss ys mYerr theta tree j2ix
    grad'        = M.map nanTo0 grad

-- | Gradient of the negative log-likelihood
--
-- Runs 'reverseModeGraph' over the expression tree and returns the sum of
-- its first component together with the gradient (each entry passed through
-- 'nanTo0').
--
-- * 'Bernoulli' requires every target to be exactly 0 or 1.
-- * 'Poisson' requires every target to be non-negative.
-- * 'ROXY' halves both the value and the gradient.
-- * 'HGaussian' is not supported by this function and raises an explicit
--   error instead of a non-exhaustive pattern failure.
gradNLLGraph :: Distribution -> SRMatrix -> PVector -> Maybe PVector -> Fix SRTree -> Vector Double -> (Double, Vector Double)
gradNLLGraph dist xss ys mYerr tree theta =
  case dist of
    MSE       -> result
    Gaussian  -> result
    Bernoulli
      | M.any (\x -> x /= 0 && x /= 1) ys -> error "For Bernoulli distribution the output must be either 0 or 1."
      | otherwise                         -> result
    Poisson
      | M.any (< 0) ys -> error "For Poisson distribution the output must be non-negative."
      | otherwise      -> result
    ROXY      -> (M.sum yhat * 0.5, VS.map (* 0.5) grad')
    HGaussian -> error "gradNLLGraph: the HGaussian distribution is not supported."
  where
    -- Shared by every supported distribution; the bindings are lazy, so the
    -- input-validation guards above still run before any AD work happens.
    (yhat, grad) = reverseModeGraph xss ys mYerr theta tree
    grad'        = VS.map nanTo0 grad
    result       = (M.sum yhat, grad')

-- | e-graph support
--
-- Gradient of the negative log-likelihood for the expression rooted at
-- e-class @root@ of an e-graph. It runs 'reverseModeEGraph' and returns the
-- sum of its first component together with the gradient (each entry passed
-- through 'nanTo0').
--
-- * 'Bernoulli' requires every target to be exactly 0 or 1.
-- * 'Poisson' requires every target to be non-negative.
-- * 'ROXY' halves both the value and the gradient.
-- * 'HGaussian' is not supported by this function and raises an explicit
--   error instead of a non-exhaustive pattern failure.
gradNLLEGraph :: Distribution -> SRMatrix -> PVector -> Maybe PVector -> EGraph -> ECache -> Int -> Vector Double -> (Double, Vector Double)
gradNLLEGraph dist xss ys mYerr egraph cache root theta =
  case dist of
    MSE       -> result
    Gaussian  -> result
    Bernoulli
      | M.any (\x -> x /= 0 && x /= 1) ys -> error "For Bernoulli distribution the output must be either 0 or 1."
      | otherwise                         -> result
    Poisson
      | M.any (< 0) ys -> error "For Poisson distribution the output must be non-negative."
      | otherwise      -> result
    ROXY      -> (M.sum yhat * 0.5, VS.map (* 0.5) grad')
    HGaussian -> error "gradNLLEGraph: the HGaussian distribution is not supported."
  where
    -- Shared by every supported distribution; the bindings are lazy, so the
    -- input-validation guards above still run before any AD work happens.
    (yhat, grad) = reverseModeEGraph xss ys mYerr egraph cache root theta
    grad'        = VS.map nanTo0 grad
    result       = (M.sum yhat, grad')

-- | Fisher information of negative log-likelihood
--
-- Returns one entry per parameter in @theta@: the second derivative of the
-- negative log-likelihood with respect to that parameter (the diagonal of
-- the observed information matrix).
--
-- * 'MSE', 'Bernoulli' and 'Poisson' use the closed form
--   @sum (phi' * (df\/dθ)^2 - (y - phi) * d²f\/dθ²)@, where @f@ is the tree
--   (after 'floatConstsToParam'), @phi@ is the inverse link applied to @f@
--   and @phi'@ its derivative.
-- * Every other distribution ('Gaussian', 'HGaussian', 'ROXY') uses a
--   central second-order finite difference of 'nll'. This path depends only
--   on 'nll', so it covers any distribution 'nll' supports.
fisherNLL :: Distribution -> Maybe PVector -> SRMatrix -> PVector -> Fix SRTree -> PVector -> SRVector
fisherNLL dist mYerr xss ys tree theta =
  case dist of
    MSE       -> analytic id       (const ones)
    Bernoulli -> analytic logistic (\q -> q * (ones - q))
    Poisson   -> analytic exp      id
    _         -> makeArray cmp (Sz p) finiteDiff
  where
    cmp    = getComp xss
    (Sz m) = M.size ys
    (Sz p) = M.size theta
    ones   = M.replicate cmp (Sz m) 1 :: SRVector

    -- Closed-form diagonal. @invLink@ maps the raw prediction to phi and
    -- @dInvLink@ maps phi to phi'.
    analytic :: (SRVector -> SRVector) -> (SRVector -> SRVector) -> SRVector
    analytic invLink dInvLink = makeArray cmp (Sz p) build
      where
        t'   = fst $ floatConstsToParam tree
        eval = evalTree xss theta
        phi  = invLink (eval t')
        phi' = dInvLink phi
        res  = delay ys - phi
        build ix =
          let dtdix   = deriveByParam ix t'
              d2tdix2 = deriveByParam ix dtdix
              f'      = eval dtdix
              f''     = eval d2tdix2
          in M.sum $ phi' * f' * f' - res * f''

    -- Central second difference of 'nll' along parameter @ix@, computed
    -- purely (no mutable copy of theta). The step is relative to the
    -- parameter's magnitude and close to ε_mach^(1/4) (about 1e-4), which
    -- balances truncation error against the cancellation in
    -- fPlus + fMinus - 2*f0; a step of 1e-6 divides rounding noise by 1e-12.
    f0 = nll dist mYerr xss ys tree theta
    finiteDiff ix =
      let h         = 1e-4 * max 1 (abs (theta M.! ix))
          shifted d = M.computeAs M.S $ M.imap (\i x -> if i == ix then x + d else x) theta
          fPlus     = nll dist mYerr xss ys tree (shifted h)
          fMinus    = nll dist mYerr xss ys tree (shifted (negate h))
      in (fPlus + fMinus - 2 * f0) / (h * h)

-- | Hessian of negative log-likelihood
--
-- Note, though the Fisher is just the diagonal of the return of this function
-- it is better to keep them as different functions for efficiency.
--
-- Entry @(i, j)@ is @sum (scale (phi' * df\/dθi * df\/dθj - (y - phi) * d²f\/dθidθj))@,
-- where @f@ is the tree evaluated on @xss@ with @theta@, @phi@ is the inverse
-- link applied to @f@ and @phi'@ its derivative.
--
-- * For 'Gaussian', @scale@ divides element-wise by the error vector
--   (@mYerr@, or a constant @m - p@ when it is absent).
-- * The other supported distributions use no scaling.
-- * 'ROXY' and 'HGaussian' are not supported and raise an explicit error.
hessianNLL :: Distribution -> Maybe PVector -> SRMatrix -> PVector -> Fix SRTree -> PVector -> SRMatrix
hessianNLL dist mYerr xss ys tree theta =
  case dist of
    MSE       -> hessianWith id yhat ones
    Gaussian  -> hessianWith (/ delay yErr) yhat ones
    Bernoulli -> let q = logistic yhat in hessianWith id q (q * (ones - q))
    Poisson   -> let q = exp yhat in hessianWith id q q
    ROXY      -> error "hessianNLL: the ROXY distribution is not supported."
    HGaussian -> error "hessianNLL: the HGaussian distribution is not supported."
  where
    cmp    = getComp xss
    (Sz m) = M.size ys
    (Sz p) = M.size theta
    t'     = tree -- relabelParams tree -- $ floatConstsToParam tree
    eval   = evalTree xss theta
    yhat   = eval t'
    ones   = M.replicate cmp (Sz m) 1 :: SRVector
    -- NOTE(review): without mYerr every sample is divided by the residual
    -- degrees of freedom (m - p) rather than by a variance estimate; confirm
    -- this is intended.
    yErr   = fromMaybe (M.replicate cmp (Sz m) est) mYerr
    est    = fromIntegral (m - p)

    -- Builds the p x p matrix for a given inverse-link output (phi), its
    -- derivative (phi') and a per-sample scaling of the summand.
    hessianWith :: (SRVector -> SRVector) -> SRVector -> SRVector -> SRMatrix
    hessianWith scale phi phi' = makeArray cmp (Sz (p :. p)) build
      where
        res = delay ys - phi
        build (ix :. iy) =
          let dtdix   = deriveByParam ix t'
              dtdiy   = deriveByParam iy t'
              d2tdixy = deriveByParam iy dtdix
              fx      = eval dtdix
              fy      = eval dtdiy
              fxy     = eval d2tdixy
          in M.sum . scale $ phi' * fx * fy - res * fxy

-- | Flatten an expression tree into an 'IntMap.IntMap' keyed by heap-style
-- node positions. The root is stored at key @0@. For the node at key @i@,
-- the first (or only) child is stored at key @2*i+1@ and the second child of
-- a binary node at key @2*i+2@.
--
-- Each node is encoded as a tuple @(arity, code, index, value)@:
--
--   * @Var ix@     -> @(0, 0, ix, -1)@
--   * @Param ix@   -> @(0, 1, ix, -1)@
--   * @Const x@    -> @(0, 2, -1, x)@
--   * @Uni f _@    -> @(1, fromEnum f, -1, -1)@
--   * @Bin op _ _@ -> @(2, fromEnum op, -1, -1)@
--
-- Fields that do not apply to a node are filled with @-1@.
--
-- NOTE(review): keys roughly double at every level, so on a 64-bit 'Int'
-- trees deeper than about 62 levels overflow and keys may collide. Confirm
-- that callers never pass such trees.
tree2arr :: Fix SRTree -> IntMap.IntMap (Int, Int, Int, Double)
tree2arr tree = IntMap.fromList (accu indexer convert tree 0)
  where
    -- Pair every child with its heap position, computed from the parent's
    -- position. This must stay polymorphic in the child type because 'accu'
    -- requires a rank-2 argument.
    indexer :: SRTree x -> Int -> SRTree (x, Int)
    indexer (Var ix)     _  = Var ix
    indexer (Param ix)   _  = Param ix
    indexer (Const x)    _  = Const x
    indexer (Uni f t)    iy = Uni f (t, 2 * iy + 1)
    indexer (Bin op l r) iy = Bin op (l, 2 * iy + 1) (r, 2 * iy + 2)

    -- Emit this node's entry, followed by the entries already produced for
    -- its children (left subtree before right subtree).
    convert :: SRTree [(Int, (Int, Int, Int, Double))]
            -> Int
            -> [(Int, (Int, Int, Int, Double))]
    convert (Var ix)     iy = [(iy, (0, 0, ix, -1))]
    convert (Param ix)   iy = [(iy, (0, 1, ix, -1))]
    convert (Const x)    iy = [(iy, (0, 2, -1, x))]
    convert (Uni f t)    iy = (iy, (1, fromEnum f, -1, -1)) : t
    convert (Bin op l r) iy = (iy, (2, fromEnum op, -1, -1)) : (l <> r)
{-# INLINE tree2arr #-}