{-# language FlexibleInstances, DeriveFunctor #-}
{-# language ScopedTypeVariables #-}
{-# language RankNTypes #-}
{-# language ViewPatterns #-}
{-# language FlexibleContexts #-}
{-# language BangPatterns #-}
{-# language TypeApplications #-}
{-# language MultiWayIf #-}
module Algorithm.SRTree.AD
( reverseModeArr
, reverseModeEGraph
, reverseModeGraph
, forwardModeUniqueJac
, evalCache
) where
import Control.Monad (forM_, foldM, when)
import Control.Monad.ST ( runST )
import Data.Bifunctor (bimap, first, second)
import qualified Data.DList as DL
import Data.Massiv.Array hiding (forM_, map, replicate, zipWith)
import qualified Data.Massiv.Array as M
import qualified Data.Massiv.Array.Unsafe as UMA
import Data.Massiv.Core.Operations (unsafeLiftArray)
import Data.SRTree.Derivative ( derivative )
import Data.SRTree.Eval
( SRVector, evalFun, evalOp, SRMatrix, PVector, replicateAs )
import Data.SRTree.Internal
import Data.SRTree.Print (showExpr)
import Data.SRTree.Recursion ( cataM, cata, accu )
import qualified Data.Vector as V
import Debug.Trace (trace, traceShow)
import GHC.IO (unsafePerformIO)
import qualified Data.IntMap.Strict as IntMap
import Data.List ( foldl' )
import qualified Data.Vector.Storable as VS
import Control.Scheduler
import Data.Maybe ( fromJust, isJust )
import Algorithm.EqSat.Egraph
import Control.Monad.State.Strict
import Control.Monad.Identity
import qualified Data.Map.Strict as Map
-- | Evaluate the e-class @root'@ of @egraph@ on the data matrix @xss@ with
-- parameter vector @theta@, memoising intermediate results.
--
-- Evaluation follows the '_best' e-node of each (canonicalised) e-class.
-- Computed values of non-leaf e-classes are split between two caches:
--
--   * a /global/ cache, seeded with @cache@, for values that do not depend
--     on any 'Param' (and therefore not on @theta@); this cache is returned;
--   * a /local/ cache for parameter-dependent values, used only during this
--     call and then discarded.
--
-- Leaf e-nodes ('Var', 'Const', 'Param') are never cached; they are
-- recomputed whenever they are needed.
--
-- Calls 'error' with @"wrong canon"@ if an e-class id is missing from the
-- e-graph's canonical map.
evalCache :: SRMatrix -> EGraph -> ECache -> EClassId -> VS.Vector Double -> ECache
evalCache xss egraph cache root' theta = cache'
  where
    -- A variable is the column slice @xss <! ix@, whose length is the number
    -- of rows of @xss@. Constants and parameters are broadcast to that same
    -- length; otherwise 'M.zipWith' would silently truncate to the shorter
    -- operand.
    (Sz2 nRows _) = M.size xss
    m    = Sz1 nRows
    comp = M.getComp xss

    -- Follow '_canonicalMap' until reaching a fixed point.
    canon :: EClassId -> EClassId
    canon rt = case _canonicalMap egraph IntMap.!? rt of
      Nothing  -> error "wrong canon"
      Just rt' -> if rt == rt' then rt else canon rt'

    -- Best e-node of the canonical representative of an e-class.
    getNode :: EClassId -> ENode
    getNode = _best . _info . (_eClass egraph IntMap.!) . canon

    -- The state is (global cache, local cache); only the global one is kept.
    (cache', _localCache) = insertKey root' `execState` (cache, IntMap.empty)

    -- Evaluate e-class @key'@ unless it is already cached, then return its
    -- value together with a flag that is 'True' iff it depends on a 'Param'.
    insertKey :: EClassId -> State (ECache, ECache) (PVector, Bool)
    insertKey key' = do
      let key = canon key'
      (global, local) <- get
      when (not (key `IntMap.member` global || key `IntMap.member` local)) $ do
        let node = getNode key
        (ev, toLocal) <- evalKey node
        modify' (insKey key node ev toLocal)
      getVal key

    -- Compute the value of a single e-node; children are fetched via 'getVal'.
    -- The flag is 'True' iff the value depends on some 'Param'.
    evalKey :: ENode -> State (ECache, ECache) (PVector, Bool)
    evalKey (Var ix)     = pure (M.computeAs S (xss <! ix), False)
    evalKey (Const v)    = pure (M.replicate comp m v, False)
    evalKey (Param ix)   = pure (M.replicate comp m (theta VS.! ix), True)
    evalKey (Uni f t)    = do
      (v, b) <- getVal t
      pure (M.computeAs S (M.map (evalFun f) v), b)
    evalKey (Bin op l r) = do
      (vl, bl) <- getVal l
      (vr, br) <- getVal r
      pure (M.computeAs S (M.zipWith (evalOp op) vl vr), bl || br)

    -- Record a freshly computed non-leaf value under its canonical id @k@
    -- (the same id 'getVal' will look up next, so that lookup always hits):
    -- in the local cache when it depends on a parameter, else in the global
    -- one. Leaves are not cached.
    insKey :: EClassId -> ENode -> PVector -> Bool -> (ECache, ECache) -> (ECache, ECache)
    insKey _ (Var _)   _ _ s = s
    insKey _ (Const _) _ _ s = s
    insKey _ (Param _) _ _ s = s
    insKey k _ v toLocal (global, local)
      | toLocal   = (global, IntMap.insert k v local)
      | otherwise = (IntMap.insert k v global, local)

    -- Leaves are evaluated directly; other e-classes go through the caches.
    getVal :: EClassId -> State (ECache, ECache) (PVector, Bool)
    getVal rt' = case n of
        Var _   -> evalKey n
        Const _ -> evalKey n
        Param _ -> evalKey n
        _       -> getFromCache rt
      where
        rt = canon rt'
        n  = getNode rt

    -- Look @rt@ up in the global cache, then the local one, evaluating it on
    -- a miss. A local hit is reported as parameter-dependent ('True').
    getFromCache :: EClassId -> State (ECache, ECache) (PVector, Bool)
    getFromCache rt = do
      (global, local) <- get
      case (global IntMap.!? rt, local IntMap.!? rt) of
        (Just v, _) -> pure (v, False)
        (_, Just v) -> pure (v, True)
        _           -> insertKey rt
reverseModeEGraph :: SRMatrix -> PVector -> Maybe PVector -> EGraph -> ECache -> EClassId -> VS.Vector Double -> (Array D Ix1 Double, VS.Vector Double)
reverseModeEGraph :: SRMatrix
-> Array S Ix1 Double
-> Maybe (Array S Ix1 Double)
-> EGraph
-> IntMap (Array S Ix1 Double)
-> Ix1
-> Vector Double
-> (Array D Ix1 Double, Vector Double)
reverseModeEGraph SRMatrix
xss Array S Ix1 Double
ys Maybe (Array S Ix1 Double)
mYErr EGraph
egraph IntMap (Array S Ix1 Double)
cache Ix1
root' Vector Double
theta =
(Array S Ix1 Double -> Array D Ix1 Double
forall ix r e.
(Index ix, Source r e) =>
Array r ix e -> Array D ix e
delay (Array S Ix1 Double -> Array D Ix1 Double)
-> Array S Ix1 Double -> Array D Ix1 Double
forall a b. (a -> b) -> a -> b
$ Array S Ix1 Double
rootVal
, [Double] -> Vector Double
forall a. Storable a => [a] -> Vector a
VS.fromList [Array S Ix1 Double -> Double
forall ix r e. (Index ix, Source r e, Num e) => Array r ix e -> e
M.sum (Array S Ix1 Double -> Double) -> Array S Ix1 Double -> Double
forall a b. (a -> b) -> a -> b
$ Map (SRTree Ix1) (Array S Ix1 Double)
cachedGrad Map (SRTree Ix1) (Array S Ix1 Double)
-> SRTree Ix1 -> Array S Ix1 Double
forall k a. Ord k => Map k a -> k -> a
Map.! (Ix1 -> SRTree Ix1
forall val. Ix1 -> SRTree val
Param Ix1
ix) | Ix1
ix <- [Ix1
0..Ix1
pIx1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
-Ix1
1]]
)
where
rootVal :: Array S Ix1 Double
rootVal = (Maybe (Array S Ix1 Double), Maybe (Array S Ix1 Double))
-> Array S Ix1 Double
forall {a}. (Maybe a, Maybe a) -> a
extractCache (IntMap (Array S Ix1 Double)
cache'' IntMap (Array S Ix1 Double) -> Ix1 -> Maybe (Array S Ix1 Double)
forall a. IntMap a -> Ix1 -> Maybe a
IntMap.!? Ix1
root', IntMap (Array S Ix1 Double)
localcache' IntMap (Array S Ix1 Double) -> Ix1 -> Maybe (Array S Ix1 Double)
forall a. IntMap a -> Ix1 -> Maybe a
IntMap.!? Ix1
root')
root :: Ix1
root = Ix1 -> Ix1
canon Ix1
root'
yErr :: Array S Ix1 Double
yErr = Maybe (Array S Ix1 Double) -> Array S Ix1 Double
forall a. HasCallStack => Maybe a -> a
fromJust Maybe (Array S Ix1 Double)
mYErr
m :: Sz Ix1
m = Array S Ix1 Double -> Sz Ix1
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size Array S Ix1 Double
ys
p :: Ix1
p = Vector Double -> Ix1
forall a. Storable a => Vector a -> Ix1
VS.length Vector Double
theta
comp :: Comp
comp = SRMatrix -> Comp
forall r ix e. Strategy r => Array r ix e -> Comp
forall ix e. Array S ix e -> Comp
M.getComp SRMatrix
xss
one :: Array S Ix1 Double
one :: Array S Ix1 Double
one = Comp -> Sz Ix1 -> Double -> Array S Ix1 Double
forall r ix e. Load r ix e => Comp -> Sz ix -> e -> Array r ix e
M.replicate Comp
comp Sz Ix1
m Double
1
canon :: Ix1 -> Ix1
canon Ix1
rt = case EGraph -> ClassIdMap Ix1
_canonicalMap EGraph
egraph ClassIdMap Ix1 -> Ix1 -> Maybe Ix1
forall a. IntMap a -> Ix1 -> Maybe a
IntMap.!? Ix1
rt of
Maybe Ix1
Nothing -> [Char] -> Ix1
forall a. HasCallStack => [Char] -> a
error [Char]
"wrong canon"
Just Ix1
rt' -> if Ix1
rt Ix1 -> Ix1 -> Bool
forall a. Eq a => a -> a -> Bool
== Ix1
rt' then Ix1
rt else Ix1 -> Ix1
canon Ix1
rt'
getNode :: Ix1 -> SRTree Ix1
getNode Ix1
rt' = let rt :: Ix1
rt = Ix1 -> Ix1
canon Ix1
rt'
cls :: EClass
cls = EGraph -> ClassIdMap EClass
_eClass EGraph
egraph ClassIdMap EClass -> Ix1 -> EClass
forall a. IntMap a -> Ix1 -> a
IntMap.! Ix1
rt
in (EClassData -> SRTree Ix1
_best (EClassData -> SRTree Ix1)
-> (EClass -> EClassData) -> EClass -> SRTree Ix1
forall b c a. (b -> c) -> (a -> b) -> a -> c
. EClass -> EClassData
_info) EClass
cls
getId :: SRTree Ix1 -> Ix1
getId SRTree Ix1
n' = let n :: SRTree Ix1
n = Identity (SRTree Ix1) -> SRTree Ix1
forall a. Identity a -> a
runIdentity (Identity (SRTree Ix1) -> SRTree Ix1)
-> Identity (SRTree Ix1) -> SRTree Ix1
forall a b. (a -> b) -> a -> b
$ SRTree Ix1 -> EGraphST Identity (SRTree Ix1)
forall (m :: * -> *).
Monad m =>
SRTree Ix1 -> EGraphST m (SRTree Ix1)
canonize SRTree Ix1
n' EGraphST Identity (SRTree Ix1) -> EGraph -> Identity (SRTree Ix1)
forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m a
`evalStateT` EGraph
egraph
in if SRTree Ix1
n SRTree Ix1 -> Map (SRTree Ix1) Ix1 -> Bool
forall k a. Ord k => k -> Map k a -> Bool
`Map.member` EGraph -> Map (SRTree Ix1) Ix1
_eNodeToEClass EGraph
egraph then EGraph -> Map (SRTree Ix1) Ix1
_eNodeToEClass EGraph
egraph Map (SRTree Ix1) Ix1 -> SRTree Ix1 -> Ix1
forall k a. Ord k => Map k a -> k -> a
Map.! SRTree Ix1
n else EGraph -> Map (SRTree Ix1) Ix1
_eNodeToEClass EGraph
egraph Map (SRTree Ix1) Ix1 -> SRTree Ix1 -> Ix1
forall k a. Ord k => Map k a -> k -> a
Map.! SRTree Ix1
n'
((IntMap (Array S Ix1 Double)
cache', IntMap (Array S Ix1 Double)
localcache), Map (SRTree Ix1) (Array S Ix1 Double)
_) = Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
evalCached Ix1
root State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
forall s a. State s a -> s -> s
`execState` ((IntMap (Array S Ix1 Double)
cache, IntMap (Array S Ix1 Double)
forall a. IntMap a
IntMap.empty), Map (SRTree Ix1) (Array S Ix1 Double)
forall k a. Map k a
Map.empty)
where
evalCached :: EClassId -> State ((ECache, ECache), Map.Map ENode PVector) (PVector, Bool)
evalCached :: Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
evalCached Ix1
rt = Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
insertKey Ix1
rt
insertKey :: EClassId -> State ((ECache, ECache), Map.Map ENode PVector) (PVector, Bool)
insertKey :: Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
insertKey Ix1
key' = do
let key :: Ix1
key = Ix1 -> Ix1
canon Ix1
key'
Bool
isCachedGlobal <- (((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> Bool)
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
Bool
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ((Ix1
key Ix1 -> IntMap (Array S Ix1 Double) -> Bool
forall a. Ix1 -> IntMap a -> Bool
`IntMap.member`) (IntMap (Array S Ix1 Double) -> Bool)
-> (((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> IntMap (Array S Ix1 Double))
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double))
-> IntMap (Array S Ix1 Double)
forall a b. (a, b) -> a
fst ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double))
-> IntMap (Array S Ix1 Double))
-> (((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> (IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)))
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> IntMap (Array S Ix1 Double)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> (IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double))
forall a b. (a, b) -> a
fst)
Bool
isCachedLocal <- (((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> Bool)
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
Bool
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ((Ix1
key Ix1 -> IntMap (Array S Ix1 Double) -> Bool
forall a. Ix1 -> IntMap a -> Bool
`IntMap.member`) (IntMap (Array S Ix1 Double) -> Bool)
-> (((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> IntMap (Array S Ix1 Double))
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double))
-> IntMap (Array S Ix1 Double)
forall a b. (a, b) -> b
snd ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double))
-> IntMap (Array S Ix1 Double))
-> (((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> (IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)))
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> IntMap (Array S Ix1 Double)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> (IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double))
forall a b. (a, b) -> a
fst)
Bool
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
()
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Bool -> Bool
not Bool
isCachedLocal Bool -> Bool -> Bool
&& Bool -> Bool
not Bool
isCachedGlobal) (StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
()
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
())
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
()
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
()
forall a b. (a -> b) -> a -> b
$ do
let node :: SRTree Ix1
node = Ix1 -> SRTree Ix1
getNode Ix1
key
(Array S Ix1 Double
ev, Bool
toLocal) <- SRTree Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
evalKey SRTree Ix1
node
(((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double)))
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' (SRTree Ix1
-> Array S Ix1 Double
-> Bool
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
forall {a} {b}.
SRTree Ix1
-> a
-> Bool
-> ((IntMap a, IntMap a), b)
-> ((IntMap a, IntMap a), b)
insKey SRTree Ix1
node Array S Ix1 Double
ev Bool
toLocal)
Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
getVal Ix1
key
evalKey :: ENode -> State ((ECache, ECache), Map.Map ENode PVector) (PVector, Bool)
evalKey :: SRTree Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
evalKey (Var Ix1
ix) = (Array S Ix1 Double, Bool)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
forall a.
a
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((Array S Ix1 Double, Bool)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool))
-> (Array S Ix1 Double, Bool)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
forall a b. (a -> b) -> a -> b
$ if | Ix1
ix Ix1 -> Ix1 -> Bool
forall a. Eq a => a -> a -> Bool
== -Ix1
1 -> (Array S Ix1 Double
ys, Bool
False)
| Ix1
ix Ix1 -> Ix1 -> Bool
forall a. Eq a => a -> a -> Bool
== -Ix1
2 -> (Array S Ix1 Double
yErr, Bool
False)
| Bool
otherwise -> (S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> Array D Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ SRMatrix
xss SRMatrix -> Ix1 -> Array D (Lower Ix2) Double
forall r ix e.
(HasCallStack, Index ix, Source r e) =>
Array r ix e -> Ix1 -> Array D (Lower ix) e
<! Ix1
ix, Bool
False)
evalKey (Const Double
v) = (Array S Ix1 Double, Bool)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
forall a.
a
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((Array S Ix1 Double, Bool)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool))
-> (Array S Ix1 Double, Bool)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
forall a b. (a -> b) -> a -> b
$ (Comp -> Sz Ix1 -> Double -> Array S Ix1 Double
forall r ix e. Load r ix e => Comp -> Sz ix -> e -> Array r ix e
M.replicate Comp
comp Sz Ix1
m Double
v, Bool
False)
evalKey (Param Ix1
ix) = (Array S Ix1 Double, Bool)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
forall a.
a
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((Array S Ix1 Double, Bool)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool))
-> (Array S Ix1 Double, Bool)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
forall a b. (a -> b) -> a -> b
$ (Comp -> Sz Ix1 -> Double -> Array S Ix1 Double
forall r ix e. Load r ix e => Comp -> Sz ix -> e -> Array r ix e
M.replicate Comp
comp Sz Ix1
m (Vector Double
theta Vector Double -> Ix1 -> Double
forall a. Storable a => Vector a -> Ix1 -> a
VS.! Ix1
ix), Bool
True)
evalKey (Uni Function
f Ix1
t) = do (Array S Ix1 Double
v, Bool
b) <- Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
getVal Ix1
t
(Array S Ix1 Double, Bool)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
forall a.
a
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((Array S Ix1 Double, Bool)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool))
-> (Array S Ix1 Double, Bool)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
forall a b. (a -> b) -> a -> b
$ (S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> (Array S Ix1 Double -> Array D Ix1 Double)
-> Array S Ix1 Double
-> Array S Ix1 Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Double -> Double) -> Array S Ix1 Double -> Array D Ix1 Double
forall ix r e' e.
(Index ix, Source r e') =>
(e' -> e) -> Array r ix e' -> Array D ix e
M.map (Function -> Double -> Double
forall a. Floating a => Function -> a -> a
evalFun Function
f) (Array S Ix1 Double -> Array S Ix1 Double)
-> Array S Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ Array S Ix1 Double
v, Bool
b)
evalKey (Bin Op
op Ix1
l Ix1
r) = do (Array S Ix1 Double
vl, Bool
bl) <- Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
getVal Ix1
l
(Array S Ix1 Double
vr, Bool
br) <- Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
getVal Ix1
r
(Array S Ix1 Double, Bool)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
forall a.
a
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((Array S Ix1 Double, Bool)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool))
-> (Array S Ix1 Double, Bool)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
forall a b. (a -> b) -> a -> b
$ (S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> Array D Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ (Double -> Double -> Double)
-> Array S Ix1 Double -> Array S Ix1 Double -> Array D Ix1 Double
forall ix r1 e1 r2 e2 e.
(Index ix, Source r1 e1, Source r2 e2) =>
(e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
M.zipWith (Op -> Double -> Double -> Double
forall a. Floating a => Op -> a -> a -> a
evalOp Op
op) Array S Ix1 Double
vl Array S Ix1 Double
vr, Bool
bl Bool -> Bool -> Bool
|| Bool
br)
insKey :: SRTree Ix1
-> a
-> Bool
-> ((IntMap a, IntMap a), b)
-> ((IntMap a, IntMap a), b)
insKey (Var Ix1
_) a
_ Bool
_ ((IntMap a, IntMap a), b)
s = ((IntMap a, IntMap a), b)
s
insKey (Const Double
_) a
_ Bool
_ ((IntMap a, IntMap a), b)
s = ((IntMap a, IntMap a), b)
s
insKey (Param Ix1
_) a
_ Bool
_ ((IntMap a, IntMap a), b)
s = ((IntMap a, IntMap a), b)
s
insKey SRTree Ix1
node a
v Bool
toLocal ((IntMap a
global,IntMap a
local), b
s) =
let k :: Ix1
k = SRTree Ix1 -> Ix1
getId SRTree Ix1
node
in if Bool
toLocal
then ((IntMap a
global, Ix1 -> a -> IntMap a -> IntMap a
forall a. Ix1 -> a -> IntMap a -> IntMap a
IntMap.insert Ix1
k a
v IntMap a
local), b
s)
else ((Ix1 -> a -> IntMap a -> IntMap a
forall a. Ix1 -> a -> IntMap a -> IntMap a
IntMap.insert Ix1
k a
v IntMap a
global, IntMap a
local), b
s)
insertLocal :: Ix1 -> a -> m ()
insertLocal Ix1
k a
v = do (a
c1, IntMap a
c2) <- m (a, IntMap a)
forall s (m :: * -> *). MonadState s m => m s
get
(a, IntMap a) -> m ()
forall s (m :: * -> *). MonadState s m => s -> m ()
put (a
c1, Ix1 -> a -> IntMap a -> IntMap a
forall a. Ix1 -> a -> IntMap a -> IntMap a
IntMap.insert Ix1
k a
v IntMap a
c2)
insertGlobal :: Ix1 -> a -> m ()
insertGlobal Ix1
k a
v = do (IntMap a
c1, b
c2) <- m (IntMap a, b)
forall s (m :: * -> *). MonadState s m => m s
get
(IntMap a, b) -> m ()
forall s (m :: * -> *). MonadState s m => s -> m ()
put (Ix1 -> a -> IntMap a -> IntMap a
forall a. Ix1 -> a -> IntMap a -> IntMap a
IntMap.insert Ix1
k a
v IntMap a
c1, b
c2)
getVal :: Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
getVal Ix1
rt' = do let rt :: Ix1
rt = Ix1 -> Ix1
canon Ix1
rt'
n :: SRTree Ix1
n = Ix1 -> SRTree Ix1
getNode Ix1
rt
case SRTree Ix1
n of
Var Ix1
ix -> SRTree Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
evalKey SRTree Ix1
n
Const Double
v -> SRTree Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
evalKey SRTree Ix1
n
Param Ix1
ix -> SRTree Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
evalKey SRTree Ix1
n
SRTree Ix1
_ -> Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
getFromCache Ix1
rt
getFromCache :: Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
getFromCache Ix1
rt = do
Maybe (Array S Ix1 Double)
global <- (((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> Maybe (Array S Ix1 Double))
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Maybe (Array S Ix1 Double))
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ((IntMap (Array S Ix1 Double) -> Ix1 -> Maybe (Array S Ix1 Double)
forall a. IntMap a -> Ix1 -> Maybe a
IntMap.!? Ix1
rt) (IntMap (Array S Ix1 Double) -> Maybe (Array S Ix1 Double))
-> (((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> IntMap (Array S Ix1 Double))
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> Maybe (Array S Ix1 Double)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double))
-> IntMap (Array S Ix1 Double)
forall a b. (a, b) -> a
fst ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double))
-> IntMap (Array S Ix1 Double))
-> (((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> (IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)))
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> IntMap (Array S Ix1 Double)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> (IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double))
forall a b. (a, b) -> a
fst)
Maybe (Array S Ix1 Double)
local <- (((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> Maybe (Array S Ix1 Double))
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Maybe (Array S Ix1 Double))
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ((IntMap (Array S Ix1 Double) -> Ix1 -> Maybe (Array S Ix1 Double)
forall a. IntMap a -> Ix1 -> Maybe a
IntMap.!? Ix1
rt) (IntMap (Array S Ix1 Double) -> Maybe (Array S Ix1 Double))
-> (((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> IntMap (Array S Ix1 Double))
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> Maybe (Array S Ix1 Double)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double))
-> IntMap (Array S Ix1 Double)
forall a b. (a, b) -> b
snd ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double))
-> IntMap (Array S Ix1 Double))
-> (((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> (IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)))
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> IntMap (Array S Ix1 Double)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> (IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double))
forall a b. (a, b) -> a
fst)
if | Maybe (Array S Ix1 Double) -> Bool
forall a. Maybe a -> Bool
isJust Maybe (Array S Ix1 Double)
global -> (Array S Ix1 Double, Bool)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
forall a.
a
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe (Array S Ix1 Double) -> Array S Ix1 Double
forall a. HasCallStack => Maybe a -> a
fromJust Maybe (Array S Ix1 Double)
global, Bool
False)
| Maybe (Array S Ix1 Double) -> Bool
forall a. Maybe a -> Bool
isJust Maybe (Array S Ix1 Double)
local -> (Array S Ix1 Double, Bool)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
forall a.
a
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe (Array S Ix1 Double) -> Array S Ix1 Double
forall a. HasCallStack => Maybe a -> a
fromJust Maybe (Array S Ix1 Double)
local, Bool
True)
| Bool
otherwise -> Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
insertKey Ix1
rt
extractCache :: (Maybe a, Maybe a) -> a
extractCache (Maybe a
Nothing, Maybe a
Nothing) = [Char] -> a
forall a. HasCallStack => [Char] -> a
error [Char]
"no root info"
extractCache (Just a
r, Maybe a
_) = a
r
extractCache (Maybe a
_, Just a
r) = a
r
((IntMap (Array S Ix1 Double)
cache'', IntMap (Array S Ix1 Double)
localcache'), Map (SRTree Ix1) (Array S Ix1 Double)
cachedGrad) = Ix1
-> Array S Ix1 Double
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
()
calcGrad Ix1
root Array S Ix1 Double
one StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
()
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
forall s a. State s a -> s -> s
`execState` ((IntMap (Array S Ix1 Double)
cache', IntMap (Array S Ix1 Double)
localcache), Map (SRTree Ix1) (Array S Ix1 Double)
forall k a. Map k a
Map.empty)
calcGrad :: Int -> Array S Ix1 Double -> State ((IntMap.IntMap (Array S Ix1 Double), IntMap.IntMap (Array S Ix1 Double)), Map.Map (SRTree Int) (Array S Ix1 Double)) ()
calcGrad :: Ix1
-> Array S Ix1 Double
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
()
calcGrad Ix1
rt Array S Ix1 Double
v = do let node :: SRTree Ix1
node = Ix1 -> SRTree Ix1
getNode Ix1
rt
case SRTree Ix1
node of
Bin Op
op Ix1
l Ix1
r -> do Array S Ix1 Double
xl <- (Array S Ix1 Double, Bool) -> Array S Ix1 Double
forall a b. (a, b) -> a
fst ((Array S Ix1 Double, Bool) -> Array S Ix1 Double)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
getVal Ix1
l
Array S Ix1 Double
xr <- (Array S Ix1 Double, Bool) -> Array S Ix1 Double
forall a b. (a, b) -> a
fst ((Array S Ix1 Double, Bool) -> Array S Ix1 Double)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
getVal Ix1
r
(Array S Ix1 Double
dl, Array S Ix1 Double
dr) <- Op
-> Array S Ix1 Double
-> Array S Ix1 Double
-> Array S Ix1 Double
-> Ix1
-> Ix1
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double, Array S Ix1 Double)
forall {r2} {r3}.
(Source r2 Double, Source r3 Double) =>
Op
-> Array S Ix1 Double
-> Array r2 Ix1 Double
-> Array r3 Ix1 Double
-> Ix1
-> Ix1
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double, Array S Ix1 Double)
diff Op
op Array S Ix1 Double
v Array S Ix1 Double
xl Array S Ix1 Double
xr Ix1
l Ix1
r
Ix1
-> Array S Ix1 Double
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
()
calcGrad Ix1
l Array S Ix1 Double
dl
Ix1
-> Array S Ix1 Double
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
()
calcGrad Ix1
r Array S Ix1 Double
dr
Uni Function
f Ix1
t -> do Array S Ix1 Double
x <- (Array S Ix1 Double, Bool) -> Array S Ix1 Double
forall a b. (a, b) -> a
fst ((Array S Ix1 Double, Bool) -> Array S Ix1 Double)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
getVal Ix1
t
Ix1
-> Array S Ix1 Double
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
()
calcGrad Ix1
t (S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> Array D Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ (Double -> Double -> Double)
-> Array S Ix1 Double -> Array D Ix1 Double -> Array D Ix1 Double
forall ix r1 e1 r2 e2 e.
(Index ix, Source r1 e1, Source r2 e2) =>
(e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
M.zipWith Double -> Double -> Double
forall a. Num a => a -> a -> a
(*) Array S Ix1 Double
v ((Double -> Double) -> Array S Ix1 Double -> Array D Ix1 Double
forall ix r e' e.
(Index ix, Source r e') =>
(e' -> e) -> Array r ix e' -> Array D ix e
M.map (Function -> Double -> Double
forall a. Floating a => Function -> a -> a
derivative Function
f) Array S Ix1 Double
x))
Param Ix1
ix -> (((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double)))
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' (Array S Ix1 Double
-> SRTree Ix1
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
-> ((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
forall {e} {ix} {k} {a} {b}.
(Storable e, Index ix, Num e, Ord k) =>
Array S ix e
-> k
-> ((a, b), Map k (Array S ix e))
-> ((a, b), Map k (Array S ix e))
insertGrad Array S Ix1 Double
v (Ix1 -> SRTree Ix1
forall val. Ix1 -> SRTree val
Param Ix1
ix))
SRTree Ix1
_ -> ()
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
()
forall a.
a
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
where
insertGrad :: Array S ix e
-> k
-> ((a, b), Map k (Array S ix e))
-> ((a, b), Map k (Array S ix e))
insertGrad Array S ix e
v k
k ((a
a, b
b), Map k (Array S ix e)
g) = ((a
a, b
b), (Array S ix e -> Array S ix e -> Array S ix e)
-> k
-> Array S ix e
-> Map k (Array S ix e)
-> Map k (Array S ix e)
forall k a. Ord k => (a -> a -> a) -> k -> a -> Map k a -> Map k a
Map.insertWith (\Array S ix e
v1 Array S ix e
v2 -> S -> Array D ix e -> Array S ix e
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D ix e -> Array S ix e) -> Array D ix e -> Array S ix e
forall a b. (a -> b) -> a -> b
$ (e -> e -> e) -> Array S ix e -> Array S ix e -> Array D ix e
forall ix r1 e1 r2 e2 e.
(Index ix, Source r1 e1, Source r2 e2) =>
(e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
M.zipWith e -> e -> e
forall a. Num a => a -> a -> a
(+) Array S ix e
v1 Array S ix e
v2) k
k Array S ix e
v Map k (Array S ix e)
g)
diff :: Op
-> Array S Ix1 Double
-> Array r2 Ix1 Double
-> Array r3 Ix1 Double
-> Ix1
-> Ix1
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double, Array S Ix1 Double)
diff Op
Add Array S Ix1 Double
dx Array r2 Ix1 Double
fx Array r3 Ix1 Double
gy Ix1
l Ix1
r = (Array S Ix1 Double, Array S Ix1 Double)
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double, Array S Ix1 Double)
forall a.
a
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Array S Ix1 Double
dx, Array S Ix1 Double
dx)
diff Op
Sub Array S Ix1 Double
dx Array r2 Ix1 Double
fx Array r3 Ix1 Double
gy Ix1
l Ix1
r = (Array S Ix1 Double, Array S Ix1 Double)
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double, Array S Ix1 Double)
forall a.
a
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Array S Ix1 Double
dx, S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> Array D Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ (Double -> Double) -> Array S Ix1 Double -> Array D Ix1 Double
forall ix r e' e.
(Index ix, Source r e') =>
(e' -> e) -> Array r ix e' -> Array D ix e
M.map Double -> Double
forall a. Num a => a -> a
negate Array S Ix1 Double
dx)
diff Op
Mul Array S Ix1 Double
dx Array r2 Ix1 Double
fx Array r3 Ix1 Double
gy Ix1
l Ix1
r = (Array S Ix1 Double, Array S Ix1 Double)
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double, Array S Ix1 Double)
forall a.
a
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> Array D Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ (Double -> Double -> Double)
-> Array S Ix1 Double -> Array r3 Ix1 Double -> Array D Ix1 Double
forall ix r1 e1 r2 e2 e.
(Index ix, Source r1 e1, Source r2 e2) =>
(e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
M.zipWith Double -> Double -> Double
forall a. Num a => a -> a -> a
(*) Array S Ix1 Double
dx Array r3 Ix1 Double
gy, S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> Array D Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ (Double -> Double -> Double)
-> Array S Ix1 Double -> Array r2 Ix1 Double -> Array D Ix1 Double
forall ix r1 e1 r2 e2 e.
(Index ix, Source r1 e1, Source r2 e2) =>
(e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
M.zipWith Double -> Double -> Double
forall a. Num a => a -> a -> a
(*) Array S Ix1 Double
dx Array r2 Ix1 Double
fx)
diff Op
Div Array S Ix1 Double
dx Array r2 Ix1 Double
fx Array r3 Ix1 Double
gy Ix1
l Ix1
r = do
let k :: Ix1
k = SRTree Ix1 -> Ix1
getId (Op -> Ix1 -> Ix1 -> SRTree Ix1
forall val. Op -> val -> val -> SRTree val
Bin Op
Div Ix1
l Ix1
r)
Array S Ix1 Double
v <- (Array S Ix1 Double, Bool) -> Array S Ix1 Double
forall a b. (a, b) -> a
fst ((Array S Ix1 Double, Bool) -> Array S Ix1 Double)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
getVal Ix1
k
(Array S Ix1 Double, Array S Ix1 Double)
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double, Array S Ix1 Double)
forall a.
a
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> Array D Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ (Double -> Double -> Double)
-> Array S Ix1 Double -> Array r3 Ix1 Double -> Array D Ix1 Double
forall ix r1 e1 r2 e2 e.
(Index ix, Source r1 e1, Source r2 e2) =>
(e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
M.zipWith Double -> Double -> Double
forall a. Fractional a => a -> a -> a
(/) Array S Ix1 Double
dx Array r3 Ix1 Double
gy
, S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> Array D Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ (Double -> Double -> Double)
-> Array S Ix1 Double -> Array D Ix1 Double -> Array D Ix1 Double
forall ix r1 e1 r2 e2 e.
(Index ix, Source r1 e1, Source r2 e2) =>
(e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
M.zipWith Double -> Double -> Double
forall a. Num a => a -> a -> a
(*) Array S Ix1 Double
dx ((Double -> Double -> Double)
-> Array S Ix1 Double -> Array r3 Ix1 Double -> Array D Ix1 Double
forall ix r1 e1 r2 e2 e.
(Index ix, Source r1 e1, Source r2 e2) =>
(e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
M.zipWith (\Double
l Double
r -> Double -> Double
forall a. Num a => a -> a
negate Double
lDouble -> Double -> Double
forall a. Fractional a => a -> a -> a
/Double
r) Array S Ix1 Double
v Array r3 Ix1 Double
gy))
diff Op
Power Array S Ix1 Double
dx Array r2 Ix1 Double
fx Array r3 Ix1 Double
gy Ix1
l Ix1
r = do
let k :: Ix1
k = SRTree Ix1 -> Ix1
getId (Op -> Ix1 -> Ix1 -> SRTree Ix1
forall val. Op -> val -> val -> SRTree val
Bin Op
Power Ix1
l Ix1
r)
Array S Ix1 Double
v <- (Array S Ix1 Double, Bool) -> Array S Ix1 Double
forall a b. (a, b) -> a
fst ((Array S Ix1 Double, Bool) -> Array S Ix1 Double)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
getVal Ix1
k
(Array S Ix1 Double, Array S Ix1 Double)
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double, Array S Ix1 Double)
forall a.
a
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ( S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> Array D Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ (Double -> Double -> Double -> Double -> Double)
-> Array S Ix1 Double
-> Array r2 Ix1 Double
-> Array r3 Ix1 Double
-> Array S Ix1 Double
-> Array D Ix1 Double
forall ix r1 e1 r2 e2 r3 e3 r4 e4 e.
(Index ix, Source r1 e1, Source r2 e2, Source r3 e3,
Source r4 e4) =>
(e1 -> e2 -> e3 -> e4 -> e)
-> Array r1 ix e1
-> Array r2 ix e2
-> Array r3 ix e3
-> Array r4 ix e4
-> Array D ix e
M.zipWith4 (\Double
d Double
f Double
g Double
vi -> Double -> Double
forall {a}. RealFloat a => a -> a
fixNaN (Double -> Double) -> Double -> Double
forall a b. (a -> b) -> a -> b
$ Double
d Double -> Double -> Double
forall a. Num a => a -> a -> a
* Double
g Double -> Double -> Double
forall a. Num a => a -> a -> a
* Double
vi Double -> Double -> Double
forall a. Fractional a => a -> a -> a
/ Double
f) Array S Ix1 Double
dx Array r2 Ix1 Double
fx Array r3 Ix1 Double
gy Array S Ix1 Double
v
, S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> Array D Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ (Double -> Double -> Double -> Double)
-> Array S Ix1 Double
-> Array r2 Ix1 Double
-> Array S Ix1 Double
-> Array D Ix1 Double
forall ix r1 e1 r2 e2 r3 e3 e.
(Index ix, Source r1 e1, Source r2 e2, Source r3 e3) =>
(e1 -> e2 -> e3 -> e)
-> Array r1 ix e1
-> Array r2 ix e2
-> Array r3 ix e3
-> Array D ix e
M.zipWith3 (\Double
d Double
f Double
vi -> Double -> Double
forall {a}. RealFloat a => a -> a
fixNaN (Double -> Double) -> Double -> Double
forall a b. (a -> b) -> a -> b
$ Double
d Double -> Double -> Double
forall a. Num a => a -> a -> a
* Double
vi Double -> Double -> Double
forall a. Num a => a -> a -> a
* Double -> Double
forall a. Floating a => a -> a
log Double
f) Array S Ix1 Double
dx Array r2 Ix1 Double
fx Array S Ix1 Double
v)
diff Op
PowerAbs Array S Ix1 Double
dx Array r2 Ix1 Double
fx Array r3 Ix1 Double
gy Ix1
l Ix1
r = do
let k :: Ix1
k = SRTree Ix1 -> Ix1
getId (Op -> Ix1 -> Ix1 -> SRTree Ix1
forall val. Op -> val -> val -> SRTree val
Bin Op
PowerAbs Ix1
l Ix1
r)
Array S Ix1 Double
v <- (Array S Ix1 Double, Bool) -> Array S Ix1 Double
forall a b. (a, b) -> a
fst ((Array S Ix1 Double, Bool) -> Array S Ix1 Double)
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Ix1
-> State
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
(Array S Ix1 Double, Bool)
getVal Ix1
k
let v2 :: Array D Ix1 Double
v2 = (Double -> Double) -> Array r2 Ix1 Double -> Array D Ix1 Double
forall ix r e' e.
(Index ix, Source r e') =>
(e' -> e) -> Array r ix e' -> Array D ix e
M.map Double -> Double
forall a. Num a => a -> a
abs Array r2 Ix1 Double
fx
v3 :: Array S Ix1 Double
v3 = S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> Array D Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ (Double -> Double -> Double)
-> Array r2 Ix1 Double -> Array r3 Ix1 Double -> Array D Ix1 Double
forall ix r1 e1 r2 e2 e.
(Index ix, Source r1 e1, Source r2 e2) =>
(e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
M.zipWith Double -> Double -> Double
forall a. Num a => a -> a -> a
(*) Array r2 Ix1 Double
fx Array r3 Ix1 Double
gy
(Array S Ix1 Double, Array S Ix1 Double)
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double, Array S Ix1 Double)
forall a.
a
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ( S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> Array D Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ (Double -> Double -> Double -> Double -> Double)
-> Array S Ix1 Double
-> Array S Ix1 Double
-> Array S Ix1 Double
-> Array D Ix1 Double
-> Array D Ix1 Double
forall ix r1 e1 r2 e2 r3 e3 r4 e4 e.
(Index ix, Source r1 e1, Source r2 e2, Source r3 e3,
Source r4 e4) =>
(e1 -> e2 -> e3 -> e4 -> e)
-> Array r1 ix e1
-> Array r2 ix e2
-> Array r3 ix e3
-> Array r4 ix e4
-> Array D ix e
M.zipWith4 (\Double
d Double
v3i Double
vi Double
v2i -> Double -> Double
forall {a}. RealFloat a => a -> a
fixNaN (Double -> Double) -> Double -> Double
forall a b. (a -> b) -> a -> b
$ Double
d Double -> Double -> Double
forall a. Num a => a -> a -> a
* Double
v3i Double -> Double -> Double
forall a. Num a => a -> a -> a
* Double
vi Double -> Double -> Double
forall a. Fractional a => a -> a -> a
/ (Double
v2iDouble -> Integer -> Double
forall a b. (Num a, Integral b) => a -> b -> a
^Integer
2)) Array S Ix1 Double
dx Array S Ix1 Double
v3 Array S Ix1 Double
v Array D Ix1 Double
v2
, S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> Array D Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ (Double -> Double -> Double -> Double)
-> Array S Ix1 Double
-> Array D Ix1 Double
-> Array S Ix1 Double
-> Array D Ix1 Double
forall ix r1 e1 r2 e2 r3 e3 e.
(Index ix, Source r1 e1, Source r2 e2, Source r3 e3) =>
(e1 -> e2 -> e3 -> e)
-> Array r1 ix e1
-> Array r2 ix e2
-> Array r3 ix e3
-> Array D ix e
M.zipWith3 (\Double
d Double
f Double
vi -> Double -> Double
forall {a}. RealFloat a => a -> a
fixNaN (Double -> Double) -> Double -> Double
forall a b. (a -> b) -> a -> b
$ Double
d Double -> Double -> Double
forall a. Num a => a -> a -> a
* Double
vi Double -> Double -> Double
forall a. Num a => a -> a -> a
* Double -> Double
forall a. Floating a => a -> a
log Double
f) Array S Ix1 Double
dx Array D Ix1 Double
v2 Array S Ix1 Double
v)
diff Op
AQ Array S Ix1 Double
dx Array r2 Ix1 Double
fx Array r3 Ix1 Double
gy Ix1
l Ix1
r = let dxl :: Array D Ix1 Double
dxl = (Double -> Double -> Double)
-> Array r3 Ix1 Double -> Array S Ix1 Double -> Array D Ix1 Double
forall ix r1 e1 r2 e2 e.
(Index ix, Source r1 e1, Source r2 e2) =>
(e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
M.zipWith (\Double
g Double
d -> Double
d Double -> Double -> Double
forall a. Num a => a -> a -> a
* (Double -> Double
forall a. Fractional a => a -> a
recip (Double -> Double) -> (Double -> Double) -> Double -> Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Double -> Double
forall a. Floating a => a -> a
sqrt (Double -> Double) -> (Double -> Double) -> Double -> Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Double -> Double -> Double
forall a. Num a => a -> a -> a
+Double
1) (Double -> Double) -> (Double -> Double) -> Double -> Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Double -> Integer -> Double
forall a b. (Num a, Integral b) => a -> b -> a
^Integer
2)) Double
g) Array r3 Ix1 Double
gy Array S Ix1 Double
dx
dxy :: Array D Ix1 Double
dxy = (Double -> Double -> Double -> Double)
-> Array r2 Ix1 Double
-> Array r3 Ix1 Double
-> Array D Ix1 Double
-> Array D Ix1 Double
forall ix r1 e1 r2 e2 r3 e3 e.
(Index ix, Source r1 e1, Source r2 e2, Source r3 e3) =>
(e1 -> e2 -> e3 -> e)
-> Array r1 ix e1
-> Array r2 ix e2
-> Array r3 ix e3
-> Array D ix e
M.zipWith3 (\Double
f Double
g Double
dl -> Double
f Double -> Double -> Double
forall a. Num a => a -> a -> a
* Double
g Double -> Double -> Double
forall a. Num a => a -> a -> a
* Double
dlDouble -> Integer -> Double
forall a b. (Num a, Integral b) => a -> b -> a
^Integer
3) Array r2 Ix1 Double
fx Array r3 Ix1 Double
gy Array D Ix1 Double
dxl
in (Array S Ix1 Double, Array S Ix1 Double)
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double, Array S Ix1 Double)
forall a.
a
-> StateT
((IntMap (Array S Ix1 Double), IntMap (Array S Ix1 Double)),
Map (SRTree Ix1) (Array S Ix1 Double))
Identity
a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> Array D Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ Array D Ix1 Double
dxl, S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> Array D Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ Array D Ix1 Double
dxy)
fixNaN :: a -> a
fixNaN a
x = if a -> Bool
forall a. RealFloat a => a -> Bool
isNaN a
x then a
0 else a
x
reverseModeGraph :: SRMatrix -> PVector -> Maybe PVector -> VS.Vector Double -> Fix SRTree -> (Array D Ix1 Double, VS.Vector Double)
reverseModeGraph :: SRMatrix
-> Array S Ix1 Double
-> Maybe (Array S Ix1 Double)
-> Vector Double
-> Fix SRTree
-> (Array D Ix1 Double, Vector Double)
reverseModeGraph SRMatrix
xss Array S Ix1 Double
ys Maybe (Array S Ix1 Double)
mYErr Vector Double
theta Fix SRTree
tree = (Array S Ix1 Double -> Array D Ix1 Double
forall ix r e.
(Index ix, Source r e) =>
Array r ix e -> Array D ix e
delay (Array S Ix1 Double -> Array D Ix1 Double)
-> Array S Ix1 Double -> Array D Ix1 Double
forall a b. (a -> b) -> a -> b
$ IntMap (Array S Ix1 Double)
cachedVal' IntMap (Array S Ix1 Double) -> Ix1 -> Array S Ix1 Double
forall a. IntMap a -> Ix1 -> a
IntMap.! Ix1
root
, [Double] -> Vector Double
forall a. Storable a => [a] -> Vector a
VS.fromList [Array S Ix1 Double -> Double
forall ix r e. (Index ix, Source r e, Num e) => Array r ix e -> e
M.sum (Array S Ix1 Double -> Double) -> Array S Ix1 Double -> Double
forall a b. (a -> b) -> a -> b
$ Map (SRTree Ix1) (Array S Ix1 Double)
cachedGrad Map (SRTree Ix1) (Array S Ix1 Double)
-> SRTree Ix1 -> Array S Ix1 Double
forall k a. Ord k => Map k a -> k -> a
Map.! (Ix1 -> SRTree Ix1
forall val. Ix1 -> SRTree val
Param Ix1
ix) | Ix1
ix <- [Ix1
0..Ix1
pIx1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
-Ix1
1]])
where
yErr :: Array S Ix1 Double
yErr = Maybe (Array S Ix1 Double) -> Array S Ix1 Double
forall a. HasCallStack => Maybe a -> a
fromJust Maybe (Array S Ix1 Double)
mYErr
m :: Sz Ix1
m = Array S Ix1 Double -> Sz Ix1
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size Array S Ix1 Double
ys
p :: Ix1
p = Vector Double -> Ix1
forall a. Storable a => Vector a -> Ix1
VS.length Vector Double
theta
comp :: Comp
comp = SRMatrix -> Comp
forall r ix e. Strategy r => Array r ix e -> Comp
forall ix e. Array S ix e -> Comp
M.getComp SRMatrix
xss
one :: Array S Ix1 Double
one :: Array S Ix1 Double
one = Comp -> Sz Ix1 -> Double -> Array S Ix1 Double
forall r ix e. Load r ix e => Comp -> Sz ix -> e -> Array r ix e
M.replicate Comp
comp Sz Ix1
m Double
1
(Map (SRTree Ix1) Ix1
key2int, IntMap (SRTree Ix1)
int2key, IntMap (Array S Ix1 Double)
cachedVal, (Ix1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
subtract Ix1
1) -> Ix1
root) = (forall x.
SRTree
(StateT
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Ix1)
Identity
x)
-> StateT
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Ix1)
Identity
(SRTree x))
-> (SRTree Ix1
-> StateT
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Ix1)
Identity
Ix1)
-> Fix SRTree
-> StateT
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Ix1)
Identity
Ix1
forall (f :: * -> *) (m :: * -> *) a.
(Functor f, Monad m) =>
(forall x. f (m x) -> m (f x)) -> (f a -> m a) -> Fix f -> m a
cataM SRTree
(StateT
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Ix1)
Identity
x)
-> StateT
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Ix1)
Identity
(SRTree x)
forall x.
SRTree
(StateT
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Ix1)
Identity
x)
-> StateT
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Ix1)
Identity
(SRTree x)
forall {f :: * -> *} {a}.
Applicative f =>
SRTree (f a) -> f (SRTree a)
leftToRight SRTree Ix1
-> StateT
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Ix1)
Identity
Ix1
forall {m :: * -> *}.
MonadState
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Ix1)
m =>
SRTree Ix1 -> m Ix1
alg Fix SRTree
tree StateT
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Ix1)
Identity
Ix1
-> (Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Ix1)
-> (Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Ix1)
forall s a. State s a -> s -> s
`execState` (Map (SRTree Ix1) Ix1
forall k a. Map k a
Map.empty, IntMap (SRTree Ix1)
forall a. IntMap a
IntMap.empty, IntMap (Array S Ix1 Double)
forall a. IntMap a
IntMap.empty, Ix1
0)
(Map (SRTree Ix1) Ix1
key2int', IntMap (SRTree Ix1)
int2key', IntMap (Array S Ix1 Double)
cachedVal', Map (SRTree Ix1) (Array S Ix1 Double)
cachedGrad) = Ix1
-> Array S Ix1 Double
-> State
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
()
calcGrad Ix1
root Array S Ix1 Double
one State
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
()
-> (Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
-> (Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
forall s a. State s a -> s -> s
`execState` (Map (SRTree Ix1) Ix1
key2int, IntMap (SRTree Ix1)
int2key, IntMap (Array S Ix1 Double)
cachedVal, Map (SRTree Ix1) (Array S Ix1 Double)
forall k a. Map k a
Map.empty)
calcGrad :: Int -> Array S Ix1 Double -> State (Map.Map (SRTree Int) Int, IntMap.IntMap (SRTree Int), IntMap.IntMap (Array S Ix1 Double), Map.Map (SRTree Int) (Array S Ix1 Double)) ()
calcGrad :: Ix1
-> Array S Ix1 Double
-> State
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
()
calcGrad Ix1
key Array S Ix1 Double
v = do SRTree Ix1
node <- ((Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
-> SRTree Ix1)
-> StateT
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(SRTree Ix1)
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ((IntMap (SRTree Ix1) -> Ix1 -> SRTree Ix1
forall a. IntMap a -> Ix1 -> a
IntMap.! Ix1
key) (IntMap (SRTree Ix1) -> SRTree Ix1)
-> ((Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
-> IntMap (SRTree Ix1))
-> (Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
-> SRTree Ix1
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
-> IntMap (SRTree Ix1)
forall {a} {b} {c} {d}. (a, b, c, d) -> b
_int2key)
case SRTree Ix1
node of
Bin Op
op Ix1
l Ix1
r -> do Array S Ix1 Double
xl <- ((Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
-> Array S Ix1 Double)
-> StateT
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double)
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets (Ix1
-> (Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
-> Array S Ix1 Double
forall {a} {b} {a} {d}. Ix1 -> (a, b, IntMap a, d) -> a
getVal Ix1
l)
Array S Ix1 Double
xr <- ((Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
-> Array S Ix1 Double)
-> StateT
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double)
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets (Ix1
-> (Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
-> Array S Ix1 Double
forall {a} {b} {a} {d}. Ix1 -> (a, b, IntMap a, d) -> a
getVal Ix1
r)
(Array S Ix1 Double
dl, Array S Ix1 Double
dr) <- Op
-> Array S Ix1 Double
-> Array S Ix1 Double
-> Array S Ix1 Double
-> Ix1
-> Ix1
-> StateT
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double, Array S Ix1 Double)
forall {f :: * -> *} {e} {val} {b} {r3} {ix} {d} {r2} {r3}.
(RealFloat e,
MonadState (Map (SRTree val) Ix1, b, IntMap (Array r3 ix e), d) f,
Storable e, Index ix, Source r3 e, Source r2 e, Source r3 e,
Ord val) =>
Op
-> Array S ix e
-> Array r2 ix e
-> Array r3 ix e
-> val
-> val
-> f (Array S ix e, Array S ix e)
diff Op
op Array S Ix1 Double
v Array S Ix1 Double
xl Array S Ix1 Double
xr Ix1
l Ix1
r
Ix1
-> Array S Ix1 Double
-> State
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
()
calcGrad Ix1
l Array S Ix1 Double
dl
Ix1
-> Array S Ix1 Double
-> State
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
()
calcGrad Ix1
r Array S Ix1 Double
dr
Uni Function
f Ix1
t -> do Array S Ix1 Double
x <- ((Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
-> Array S Ix1 Double)
-> StateT
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
Identity
(Array S Ix1 Double)
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets (Ix1
-> (Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
-> Array S Ix1 Double
forall {a} {b} {a} {d}. Ix1 -> (a, b, IntMap a, d) -> a
getVal Ix1
t)
Ix1
-> Array S Ix1 Double
-> State
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
()
calcGrad Ix1
t (S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> Array D Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ (Double -> Double -> Double)
-> Array S Ix1 Double -> Array D Ix1 Double -> Array D Ix1 Double
forall ix r1 e1 r2 e2 e.
(Index ix, Source r1 e1, Source r2 e2) =>
(e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
M.zipWith Double -> Double -> Double
forall a. Num a => a -> a -> a
(*) Array S Ix1 Double
v ((Double -> Double) -> Array S Ix1 Double -> Array D Ix1 Double
forall ix r e' e.
(Index ix, Source r e') =>
(e' -> e) -> Array r ix e' -> Array D ix e
M.map (Function -> Double -> Double
forall a. Floating a => Function -> a -> a
derivative Function
f) Array S Ix1 Double
x))
Param Ix1
ix -> ((Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
-> (Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double),
Map (SRTree Ix1) (Array S Ix1 Double)))
-> State
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' (Array S Ix1 Double
-> SRTree Ix1
-> (Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
-> (Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
forall {e} {ix} {k} {a} {b} {c}.
(Storable e, Index ix, Num e, Ord k) =>
Array S ix e
-> k
-> (a, b, c, Map k (Array S ix e))
-> (a, b, c, Map k (Array S ix e))
insertGrad Array S Ix1 Double
v (Ix1 -> SRTree Ix1
forall val. Ix1 -> SRTree val
Param Ix1
ix))
SRTree Ix1
_ -> ()
-> State
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
()
forall a.
a
-> StateT
(Map (SRTree Ix1) Ix1, IntMap (SRTree Ix1),
IntMap (Array S Ix1 Double), Map (SRTree Ix1) (Array S Ix1 Double))
Identity
a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
where
_int2key :: (a, b, c, d) -> b
_int2key (a
_, b
b, c
_, d
_) = b
b
insertGrad :: Array S ix e
-> k
-> (a, b, c, Map k (Array S ix e))
-> (a, b, c, Map k (Array S ix e))
insertGrad Array S ix e
v k
k (a
a, b
b, c
c, Map k (Array S ix e)
g) = (a
a, b
b, c
c, (Array S ix e -> Array S ix e -> Array S ix e)
-> k
-> Array S ix e
-> Map k (Array S ix e)
-> Map k (Array S ix e)
forall k a. Ord k => (a -> a -> a) -> k -> a -> Map k a -> Map k a
Map.insertWith (\Array S ix e
v1 Array S ix e
v2 -> S -> Array D ix e -> Array S ix e
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D ix e -> Array S ix e) -> Array D ix e -> Array S ix e
forall a b. (a -> b) -> a -> b
$ (e -> e -> e) -> Array S ix e -> Array S ix e -> Array D ix e
forall ix r1 e1 r2 e2 e.
(Index ix, Source r1 e1, Source r2 e2) =>
(e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
M.zipWith e -> e -> e
forall a. Num a => a -> a -> a
(+) Array S ix e
v1 Array S ix e
v2) k
k Array S ix e
v Map k (Array S ix e)
g)
graph :: (a, b, c, d) -> a
graph (a
a, b
_, c
_, d
_) = a
a
insKey :: a
-> a
-> (Map a Ix1, IntMap a, IntMap a, Ix1)
-> (Map a Ix1, IntMap a, IntMap a, Ix1)
insKey a
key a
ev (Map a Ix1
a, IntMap a
b, IntMap a
c, Ix1
d) = (a -> Ix1 -> Map a Ix1 -> Map a Ix1
forall k a. Ord k => k -> a -> Map k a -> Map k a
Map.insert a
key Ix1
d Map a Ix1
a, Ix1 -> a -> IntMap a -> IntMap a
forall a. Ix1 -> a -> IntMap a -> IntMap a
IntMap.insert Ix1
d a
key IntMap a
b, Ix1 -> a -> IntMap a -> IntMap a
forall a. Ix1 -> a -> IntMap a -> IntMap a
IntMap.insert Ix1
d a
ev IntMap a
c, Ix1
dIx1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
+Ix1
1)
getVal :: Ix1 -> (a, b, IntMap a, d) -> a
getVal Ix1
key (a
a, b
b, IntMap a
c, d
d) = IntMap a
c IntMap a -> Ix1 -> a
forall a. IntMap a -> Ix1 -> a
IntMap.! Ix1
key
getKey :: k -> (Map k a, b, c, d) -> a
getKey k
key (Map k a
a, b
b, c
c, d
d) = Map k a
a Map k a -> k -> a
forall k a. Ord k => Map k a -> k -> a
Map.! k
key
leftToRight :: SRTree (f a) -> f (SRTree a)
leftToRight (Uni Function
f f a
mt) = Function -> a -> SRTree a
forall val. Function -> val -> SRTree val
Uni Function
f (a -> SRTree a) -> f a -> f (SRTree a)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> f a
mt;
leftToRight (Bin Op
f f a
ml f a
mr) = Op -> a -> a -> SRTree a
forall val. Op -> val -> val -> SRTree val
Bin Op
f (a -> a -> SRTree a) -> f a -> f (a -> SRTree a)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> f a
ml f (a -> SRTree a) -> f a -> f (SRTree a)
forall a b. f (a -> b) -> f a -> f b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> f a
mr
leftToRight (Var Ix1
ix) = SRTree a -> f (SRTree a)
forall a. a -> f a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Ix1 -> SRTree a
forall val. Ix1 -> SRTree val
Var Ix1
ix)
leftToRight (Param Ix1
ix) = SRTree a -> f (SRTree a)
forall a. a -> f a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Ix1 -> SRTree a
forall val. Ix1 -> SRTree val
Param Ix1
ix)
leftToRight (Const Double
c) = SRTree a -> f (SRTree a)
forall a. a -> f a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Double -> SRTree a
forall val. Double -> SRTree val
Const Double
c)
evalKey :: SRTree Ix1 -> f (Array S Ix1 Double)
evalKey (Var Ix1
ix) = Array S Ix1 Double -> f (Array S Ix1 Double)
forall a. a -> f a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Array S Ix1 Double -> f (Array S Ix1 Double))
-> Array S Ix1 Double -> f (Array S Ix1 Double)
forall a b. (a -> b) -> a -> b
$ if Ix1
ix Ix1 -> Ix1 -> Bool
forall a. Eq a => a -> a -> Bool
== -Ix1
1
then Array S Ix1 Double
ys
else if Ix1
ix Ix1 -> Ix1 -> Bool
forall a. Eq a => a -> a -> Bool
== -Ix1
2
then Array S Ix1 Double
yErr
else S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> Array D Ix1 Double -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ SRMatrix
xss SRMatrix -> Ix1 -> Array D (Lower Ix2) Double
forall r ix e.
(HasCallStack, Index ix, Source r e) =>
Array r ix e -> Ix1 -> Array D (Lower ix) e
<! Ix1
ix
evalKey (Const Double
v) = Array S Ix1 Double -> f (Array S Ix1 Double)
forall a. a -> f a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Array S Ix1 Double -> f (Array S Ix1 Double))
-> Array S Ix1 Double -> f (Array S Ix1 Double)
forall a b. (a -> b) -> a -> b
$ Comp -> Sz Ix1 -> Double -> Array S Ix1 Double
forall r ix e. Load r ix e => Comp -> Sz ix -> e -> Array r ix e
M.replicate Comp
comp Sz Ix1
m Double
v
evalKey (Param Ix1
ix) = Array S Ix1 Double -> f (Array S Ix1 Double)
forall a. a -> f a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Array S Ix1 Double -> f (Array S Ix1 Double))
-> Array S Ix1 Double -> f (Array S Ix1 Double)
forall a b. (a -> b) -> a -> b
$ Comp -> Sz Ix1 -> Double -> Array S Ix1 Double
forall r ix e. Load r ix e => Comp -> Sz ix -> e -> Array r ix e
M.replicate Comp
comp Sz Ix1
m (Vector Double
theta Vector Double -> Ix1 -> Double
forall a. Storable a => Vector a -> Ix1 -> a
VS.! Ix1
ix)
evalKey (Uni Function
f Ix1
t) = S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> (Array r2 Ix1 Double -> Array D Ix1 Double)
-> Array r2 Ix1 Double
-> Array S Ix1 Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Double -> Double) -> Array r2 Ix1 Double -> Array D Ix1 Double
forall ix r e' e.
(Index ix, Source r e') =>
(e' -> e) -> Array r ix e' -> Array D ix e
M.map (Function -> Double -> Double
forall a. Floating a => Function -> a -> a
evalFun Function
f) (Array r2 Ix1 Double -> Array S Ix1 Double)
-> f (Array r2 Ix1 Double) -> f (Array S Ix1 Double)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((a, b, IntMap (Array r2 Ix1 Double), d) -> Array r2 Ix1 Double)
-> f (Array r2 Ix1 Double)
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets (Ix1
-> (a, b, IntMap (Array r2 Ix1 Double), d) -> Array r2 Ix1 Double
forall {a} {b} {a} {d}. Ix1 -> (a, b, IntMap a, d) -> a
getVal Ix1
t)
evalKey (Bin Op
op Ix1
l Ix1
r) = S -> Array D Ix1 Double -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> Array S Ix1 Double)
-> f (Array D Ix1 Double) -> f (Array S Ix1 Double)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Double -> Double -> Double)
-> Array r2 Ix1 Double -> Array r2 Ix1 Double -> Array D Ix1 Double
forall ix r1 e1 r2 e2 e.
(Index ix, Source r1 e1, Source r2 e2) =>
(e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
M.zipWith (Op -> Double -> Double -> Double
forall a. Floating a => Op -> a -> a -> a
evalOp Op
op) (Array r2 Ix1 Double -> Array r2 Ix1 Double -> Array D Ix1 Double)
-> f (Array r2 Ix1 Double)
-> f (Array r2 Ix1 Double -> Array D Ix1 Double)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((a, b, IntMap (Array r2 Ix1 Double), d) -> Array r2 Ix1 Double)
-> f (Array r2 Ix1 Double)
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets (Ix1
-> (a, b, IntMap (Array r2 Ix1 Double), d) -> Array r2 Ix1 Double
forall {a} {b} {a} {d}. Ix1 -> (a, b, IntMap a, d) -> a
getVal Ix1
l) f (Array r2 Ix1 Double -> Array D Ix1 Double)
-> f (Array r2 Ix1 Double) -> f (Array D Ix1 Double)
forall a b. f (a -> b) -> f a -> f b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> ((a, b, IntMap (Array r2 Ix1 Double), d) -> Array r2 Ix1 Double)
-> f (Array r2 Ix1 Double)
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets (Ix1
-> (a, b, IntMap (Array r2 Ix1 Double), d) -> Array r2 Ix1 Double
forall {a} {b} {a} {d}. Ix1 -> (a, b, IntMap a, d) -> a
getVal Ix1
r))
-- Algebra over an expression node whose children have already been replaced
-- by their integer keys (presumably used as the algebra of 'cataM' — confirm
-- at the call site). Each node is rebuilt unchanged and passed to
-- 'insertKey', which returns the node's key after caching its value if needed.
alg node = case node of
  Var ix         -> insertKey (Var ix)
  Param ix       -> insertKey (Param ix)
  Const c        -> insertKey (Const c)
  Uni fun child  -> insertKey (Uni fun child)
  Bin op lhs rhs -> insertKey (Bin op lhs rhs)
-- Local reverse-mode rule for a binary node @Bin op l r@.
--
--   * @dx@     : upstream adjoint arriving at this node;
--   * @fx@     : evaluated left child (l);
--   * @gy@     : evaluated right child (r);
--   * @l@, @r@ : child keys, used only to fetch this node's own cached value
--                through 'getKey' / 'getVal' when a rule needs it.
--
-- Returns @(dx * d op/dl, dx * d op/dr)@, the adjoints for the left and right
-- children, as stored arrays.
diff Add dx _ _ _ _ = pure (dx, dx)
diff Sub dx _ _ _ _ = pure (dx, M.computeAs S $ M.map negate dx)
diff Mul dx fx gy _ _ =
  -- d(l*r)/dl = r ; d(l*r)/dr = l
  pure ( M.computeAs S $ M.zipWith (*) dx gy
       , M.computeAs S $ M.zipWith (*) dx fx )
diff Div dx _ gy l r = do
  -- d(l/r)/dl = 1/r ; d(l/r)/dr = -(l/r)/r, reusing the cached value of l/r
  k <- gets (getKey (Bin Div l r))
  v <- gets (getVal k)
  pure ( M.computeAs S $ M.zipWith (/) dx gy
       , M.computeAs S $ M.zipWith (*) dx (M.zipWith (\q y -> negate q / y) v gy) )
diff Power dx fx gy l r = do
  -- v = l ** r (cached) ; d/dl = r * v / l ; d/dr = v * log l
  -- fixNaN maps the NaN produced at singular points (e.g. l = 0) to 0
  k <- gets (getKey (Bin Power l r))
  v <- gets (getVal k)
  pure ( M.computeAs S $ M.zipWith4 (\d f g vi -> fixNaN $ d * g * vi / f) dx fx gy v
       , M.computeAs S $ M.zipWith3 (\d f vi -> fixNaN $ d * vi * log f) dx fx v )
diff PowerAbs dx fx gy l r = do
  -- v = |l| ** r (cached) ; d/dl = r * l * v / |l|^2 ; d/dr = v * log |l|
  k <- gets (getKey (Bin PowerAbs l r))
  v <- gets (getVal k)
  let v2 = M.map abs fx
      v3 = M.computeAs S $ M.zipWith (*) fx gy
  pure ( M.computeAs S $ M.zipWith4 (\d v3i vi v2i -> fixNaN $ d * v3i * vi / (v2i ^ 2)) dx v3 v v2
       , M.computeAs S $ M.zipWith3 (\d f vi -> fixNaN $ d * vi * log f) dx v2 v )
diff AQ dx fx gy _ _ =
  -- Analytic quotient: AQ(l, r) = l / sqrt (1 + r^2).
  --   d/dl = 1 / sqrt (1 + r^2)
  --   d/dr = -l * r / (1 + r^2)^(3/2)
  -- The upstream adjoint must be applied exactly once to each partial. The
  -- previous code cubed an expression that already contained it and lost the
  -- minus sign.
  let invNorm = M.map (recip . sqrt . (+ 1) . (^ (2 :: Int))) gy
      dxl     = M.zipWith (*) dx invNorm
      dxy     = M.zipWith4 (\d f g s -> negate (d * f * g) * s ^ (3 :: Int)) dx fx gy invNorm
  in pure (M.computeAs S dxl, M.computeAs S dxy)
-- Replace a NaN by 0 and leave every other value (including infinities)
-- untouched.
fixNaN x
  | isNaN x   = 0
  | otherwise = x
-- Intern @key@ in the state and return its integer key.
-- When @key@ is not yet in the state's first component (the map read via
-- 'graph'), its value is computed with 'evalKey' and stored with 'insKey'.
-- In every case the key is then read back with 'getKey'.
insertKey key = do
  present <- gets (Map.member key . graph)
  if present
    then pure ()
    else do
      value <- evalKey key
      modify' (insKey key value)
  gets (getKey key)
reverseModeArr :: SRMatrix
-> PVector
-> Maybe PVector
-> VS.Vector Double
-> [(Int, (Int, Int, Int, Double))]
-> IntMap.IntMap Int
-> (Array D Ix1 Double, Array S Ix1 Double)
reverseModeArr :: SRMatrix
-> Array S Ix1 Double
-> Maybe (Array S Ix1 Double)
-> Vector Double
-> [(Ix1, (Ix1, Ix1, Ix1, Double))]
-> ClassIdMap Ix1
-> (Array D Ix1 Double, Array S Ix1 Double)
reverseModeArr SRMatrix
xss Array S Ix1 Double
ys Maybe (Array S Ix1 Double)
mYErr Vector Double
theta [(Ix1, (Ix1, Ix1, Ix1, Double))]
t ClassIdMap Ix1
j2ix =
IO (Array D Ix1 Double, Array S Ix1 Double)
-> (Array D Ix1 Double, Array S Ix1 Double)
forall a. IO a -> a
unsafePerformIO (IO (Array D Ix1 Double, Array S Ix1 Double)
-> (Array D Ix1 Double, Array S Ix1 Double))
-> IO (Array D Ix1 Double, Array S Ix1 Double)
-> (Array D Ix1 Double, Array S Ix1 Double)
forall a b. (a -> b) -> a -> b
$ do
MArray RealWorld S Ix2 Double
fwd <- Sz Ix2 -> Double -> IO (MArray (PrimState IO) S Ix2 Double)
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
Sz ix -> e -> m (MArray (PrimState m) r ix e)
forall ix (m :: * -> *).
(Index ix, PrimMonad m) =>
Sz ix -> Double -> m (MArray (PrimState m) S ix Double)
M.newMArray (Ix1 -> Ix1 -> Sz Ix2
Sz2 Ix1
n Ix1
m) Double
0
MArray RealWorld S Ix2 Double
partial <- Sz Ix2 -> Double -> IO (MArray (PrimState IO) S Ix2 Double)
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
Sz ix -> e -> m (MArray (PrimState m) r ix e)
forall ix (m :: * -> *).
(Index ix, PrimMonad m) =>
Sz ix -> Double -> m (MArray (PrimState m) S ix Double)
M.newMArray (Ix1 -> Ix1 -> Sz Ix2
Sz2 Ix1
n Ix1
m) Double
0
MArray RealWorld S Ix1 Double
jacob <- Sz Ix1 -> Double -> IO (MArray (PrimState IO) S Ix1 Double)
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
Sz ix -> e -> m (MArray (PrimState m) r ix e)
forall ix (m :: * -> *).
(Index ix, PrimMonad m) =>
Sz ix -> Double -> m (MArray (PrimState m) S ix Double)
M.newMArray (Ix1 -> Sz Ix1
forall ix. Index ix => ix -> Sz ix
Sz Ix1
p) Double
0
MArray RealWorld S Ix1 Double
val <- Sz Ix1 -> Double -> IO (MArray (PrimState IO) S Ix1 Double)
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
Sz ix -> e -> m (MArray (PrimState m) r ix e)
forall ix (m :: * -> *).
(Index ix, PrimMonad m) =>
Sz ix -> Double -> m (MArray (PrimState m) S ix Double)
M.newMArray (Ix1 -> Sz Ix1
forall ix. Index ix => ix -> Sz ix
Sz Ix1
m) Double
0
let
stps :: Integer
stps = Integer
2
(Ix1
a, Ix1
b) = (Ix1
0, Ix1
m)
(Ix1, Ix1) -> MArray (PrimState IO) S Ix2 Double -> IO ()
forward (Ix1
a, Ix1
b) MArray RealWorld S Ix2 Double
MArray (PrimState IO) S Ix2 Double
fwd
(Ix1, Ix1)
-> MArray (PrimState IO) S Ix2 Double
-> MArray (PrimState IO) S Ix1 Double
-> IO ()
calculateYHat (Ix1
a, Ix1
b) MArray RealWorld S Ix2 Double
MArray (PrimState IO) S Ix2 Double
fwd MArray RealWorld S Ix1 Double
MArray (PrimState IO) S Ix1 Double
val
(Ix1, Ix1)
-> MArray (PrimState IO) S Ix2 Double
-> MArray (PrimState IO) S Ix2 Double
-> IO ()
reverseMode (Ix1
a, Ix1
b) MArray RealWorld S Ix2 Double
MArray (PrimState IO) S Ix2 Double
fwd MArray RealWorld S Ix2 Double
MArray (PrimState IO) S Ix2 Double
partial
(Ix1, Ix1)
-> MArray (PrimState IO) S Ix2 Double
-> MArray (PrimState IO) S Ix1 Double
-> IO ()
combine (Ix1
a, Ix1
b) MArray RealWorld S Ix2 Double
MArray (PrimState IO) S Ix2 Double
partial MArray RealWorld S Ix1 Double
MArray (PrimState IO) S Ix1 Double
jacob
Array S Ix1 Double
j <- Comp
-> MArray (PrimState IO) S Ix1 Double -> IO (Array S Ix1 Double)
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
Comp -> MArray (PrimState m) r ix e -> m (Array r ix e)
forall ix (m :: * -> *).
(Index ix, PrimMonad m) =>
Comp -> MArray (PrimState m) S ix Double -> m (Array S ix Double)
UMA.unsafeFreeze (SRMatrix -> Comp
forall r ix e. Strategy r => Array r ix e -> Comp
forall ix e. Array S ix e -> Comp
getComp SRMatrix
xss) MArray RealWorld S Ix1 Double
MArray (PrimState IO) S Ix1 Double
jacob
Array S Ix1 Double
v <- Comp
-> MArray (PrimState IO) S Ix1 Double -> IO (Array S Ix1 Double)
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
Comp -> MArray (PrimState m) r ix e -> m (Array r ix e)
forall ix (m :: * -> *).
(Index ix, PrimMonad m) =>
Comp -> MArray (PrimState m) S ix Double -> m (Array S ix Double)
UMA.unsafeFreeze (SRMatrix -> Comp
forall r ix e. Strategy r => Array r ix e -> Comp
forall ix e. Array S ix e -> Comp
getComp SRMatrix
xss) MArray RealWorld S Ix1 Double
MArray (PrimState IO) S Ix1 Double
val
(Array D Ix1 Double, Array S Ix1 Double)
-> IO (Array D Ix1 Double, Array S Ix1 Double)
forall a. a -> IO a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Array S Ix1 Double -> Array D Ix1 Double
forall ix r e.
(Index ix, Source r e) =>
Array r ix e -> Array D ix e
delay Array S Ix1 Double
v, Array S Ix1 Double
j)
where
(Sz2 Ix1
m Ix1
_) = SRMatrix -> Sz Ix2
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size SRMatrix
xss
p :: Ix1
p = Vector Double -> Ix1
forall a. Storable a => Vector a -> Ix1
VS.length Vector Double
theta
n :: Ix1
n = [(Ix1, (Ix1, Ix1, Ix1, Double))] -> Ix1
forall a. [a] -> Ix1
forall (t :: * -> *) a. Foldable t => t a -> Ix1
length [(Ix1, (Ix1, Ix1, Ix1, Double))]
t
toLin :: Ix1 -> Ix1 -> Ix1
toLin Ix1
i Ix1
j = Ix1
iIx1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
*Ix1
m Ix1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
+ Ix1
j
yErr :: Array S Ix1 Double
yErr = Maybe (Array S Ix1 Double) -> Array S Ix1 Double
forall a. HasCallStack => Maybe a -> a
fromJust Maybe (Array S Ix1 Double)
mYErr
eps :: Double
eps = Double
1e-8
-- Element-strict 'forM_': runs @f@ on each element left to right, forcing
-- the element to WHNF before its action runs, and discards the results.
myForM_ xs f = go xs
  where
    go []         = pure ()
    go (y : rest) = y `seq` (f y >> go rest)
{-# INLINE myForM_ #-}
-- Copy row 0 of the forward-value matrix @fwd@ into @yhat@ for columns
-- @from .. to - 1@. The caller freezes @yhat@ and returns it as the model
-- output. NOTE(review): this relies on row 0 being the root node's row.
calculateYHat :: (Int, Int) -> MArray (PrimState IO) S Ix2 Double -> MArray (PrimState IO) S Ix1 Double -> IO ()
calculateYHat (from, to) fwd yhat = myForM_ [from .. to - 1] copyColumn
  where
    copyColumn col = UMA.unsafeRead fwd (0 :. col) >>= UMA.unsafeWrite yhat col
{-# INLINE calculateYHat #-}
-- Forward pass: for every observation @i@ in @a .. b - 1@, evaluate each node
-- of @t@ and store node @j@'s value at @(j2ix ! j, i)@ in @fwd@.
-- @t@ is walked in reverse. NOTE(review): presumably this is so children are
-- written before their parents read them; confirm the order produced upstream.
--
-- Entries have the shape @(j, (kind, sub, ix, x))@:
--   * kind 0, sub 0 : column @ix@ of @xss@; @ys@ when ix == -1; @yErr@ when ix == -2
--   * kind 0, sub 1 : parameter @theta ! ix@
--   * kind 0, sub 2 : constant @x@
--   * kind 1        : function @toEnum sub@ of the child keyed @2*j + 1@
--   * kind 2        : operator @toEnum sub@ on children keyed @2*j + 1@ and @2*j + 2@
--   * anything else : no-op
forward :: (Int, Int) -> MArray (PrimState IO) S Ix2 Double -> IO ()
forward (a, b) fwd = myForM_ (Prelude.reverse t) makeFwd
  where
    -- row of @fwd@ that holds the values of node @j@
    rowOf j = j2ix IntMap.! j
    -- write @mk i@ into node @j@'s row for every observation @i@
    fillRow j mk =
      let row = rowOf j
      in myForM_ [a .. b - 1] $ \i -> mk i >>= UMA.unsafeWrite fwd (row :. i)
    makeFwd (j, (kind, sub, ix, x)) =
      case (kind, sub) of
        (0, 0) ->
          fillRow j $ \i -> pure $ case ix of
            (-1) -> ys M.! i
            (-2) -> yErr M.! i
            _    -> xss M.! (i :. ix)
        (0, 1) ->
          let v = theta VS.! ix
          in fillRow j (\_ -> pure v)
        (0, 2) ->
          fillRow j (\_ -> pure x)
        (1, _) ->
          let child = rowOf (2 * j + 1)
          in fillRow j $ \i ->
               evalFun (toEnum sub) <$> UMA.unsafeRead fwd (child :. i)
        (2, _) ->
          let lhs = rowOf (2 * j + 1)
              rhs = rowOf (2 * j + 2)
          in fillRow j $ \i ->
               evalOp (toEnum sub)
                 <$> UMA.unsafeRead fwd (lhs :. i)
                 <*> UMA.unsafeRead fwd (rhs :. i)
        _ -> pure ()
    {-# INLINE makeFwd #-}
{-# INLINE forward #-}
-- Backward (adjoint) sweep over the columns @a .. b - 1@.
--
-- Row 0 of @partial@ is first set to 1 in every column of the range. Then
-- every entry @(j, (tag, code, _, _))@ of @t@ is visited in list order:
--
--   * tag 1 (unary, function @toEnum code@): the adjoint in row @j2ix ! j@
--     is multiplied by the function's derivative at the operand value (read
--     from @fwd@) and written to the operand row @j2ix ! (2*j + 1)@;
--   * tag 2 (binary, operator @toEnum code@): 'diff' splits the adjoint
--     between the operand rows @j2ix ! (2*j + 1)@ and @j2ix ! (2*j + 2)@;
--   * any other tag is ignored.
--
-- NOTE(review): operand rows are overwritten rather than accumulated, so this
-- presumably relies on every node having a single parent and on @t@ listing
-- parents before their children -- confirm against the producer of @t@.
reverseMode :: (Int, Int) -> MArray (PrimState IO) S Ix2 Double -> MArray (PrimState IO) S Ix2 Double -> IO ()
reverseMode (a, b) fwd partial = do
    myForM_ [a .. b - 1] $ \col -> UMA.unsafeWrite partial (0 :. col) 1
    myForM_ t pushDown
  where
    -- Row of the work arrays that stores the node with key @k@.
    rowOf k = j2ix IntMap.! k

    -- Unary node: adjoint of the single operand.
    pushDown (j, (1, f, _, _)) = do
      let outRow = rowOf j
          argRow = rowOf (2 * j + 1)
      myForM_ [a .. b - 1] $ \col -> do
        x   <- UMA.unsafeRead fwd (argRow :. col)
        adj <- UMA.unsafeRead partial (outRow :. col)
        UMA.unsafeWrite partial (argRow :. col) (adj * derivative (toEnum f) x)
    -- Binary node: adjoints of the left and right operands.
    pushDown (j, (2, op, _, _)) = do
      let outRow   = rowOf j
          leftRow  = rowOf (2 * j + 1)
          rightRow = rowOf (2 * j + 2)
      myForM_ [a .. b - 1] $ \col -> do
        lv  <- UMA.unsafeRead fwd (leftRow :. col)
        rv  <- UMA.unsafeRead fwd (rightRow :. col)
        adj <- UMA.unsafeRead partial (outRow :. col)
        let (adjL, adjR) = diff (toEnum op) adj lv rv
        UMA.unsafeWrite partial (leftRow :. col) adjL
        UMA.unsafeWrite partial (rightRow :. col) adjR
    -- Leaves (and any other tag) propagate nothing.
    pushDown _ = pure ()
    {-# INLINE pushDown #-}
{-# INLINE reverseMode #-}
-- | Replace a NaN with 0 and return every other value unchanged
-- (infinities included). 'diff' uses it to keep NaNs produced by the
-- power rules out of the propagated adjoints.
fixNaN :: RealFloat a => a -> a
fixNaN x
  | isNaN x   = 0
  | otherwise = x
-- | Reverse-mode (adjoint) rule for a binary operator.
--
-- @diff op dx fx gy@ takes the adjoint @dx@ flowing into the node
-- @op fx gy@, together with the left operand value @fx@ and the right
-- operand value @gy@. It returns the adjoints of the two operands:
-- @(dx * d(op fx gy)/dfx, dx * d(op fx gy)/dgy)@.
--
-- For 'Power' and 'PowerAbs':
--
--   * a zero incoming adjoint short-circuits to @(0, 0)@;
--   * NaN results are replaced by 0 via 'fixNaN';
--   * a zero base whose derivative is singular uses @eps@ (not defined in
--     this block) as a finite surrogate.
diff :: Op -> Double -> Double -> Double -> (Double, Double)
diff Add dx _ _   = (dx, dx)
diff Sub dx _ _   = (dx, negate dx)
diff Mul dx fx gy = (dx * gy, dx * fx)
diff Div dx fx gy = (dx / gy, dx * (negate fx / (gy * gy)))
-- x ** y
diff Power 0 _ _    = (0, 0)
diff Power _ 0 0    = (0, 0)
diff Power dx fx 0  = (0, fixNaN $ dx * log fx)
diff Power dx 0 gy  = (fixNaN $ dx * gy * dBase, 0)
  where
    -- d/dx x^y at x = 0 is y * 0^(y-1):
    --   * for y < 1 it is singular, so eps stands in for the base;
    --   * for y == 1 the power factor 0^0 is 1 (d/dx x = 1);
    --   * for y > 1 it vanishes.
    dBase
      | gy < 1    = eps ** (gy - 1)
      | gy == 1   = 1
      | otherwise = 0
diff Power dx fx gy = ( fixNaN $ dx * gy * fx ** (gy - 1)
                      , fixNaN $ dx * fx ** gy * log fx )
-- |x| ** y
diff PowerAbs 0 _ _    = (0, 0)
diff PowerAbs dx fx 0  = (0, fixNaN $ dx * log (abs fx))
diff PowerAbs dx 0 gy  = (0, fixNaN $ dx * if gy < 0 then eps ** gy else 0)
diff PowerAbs dx fx gy = ( fixNaN $ dx * gy * fx * abs fx ** (gy - 2)
                         , fixNaN $ dx * abs fx ** gy * log (abs fx) )
-- Analytic quotient: aq(x, y) = x / sqrt(1 + y^2)
--   d/dx =  1 / sqrt(1 + y^2)
--   d/dy = -x * y / (1 + y^2)^(3/2)   (note the minus sign)
diff AQ dx fx gy =
  let dxl = recip (sqrt (1 + gy * gy))
      dxy = negate (fx * gy * dxl ^ (3 :: Int))
   in (dxl * dx, dxy * dx)
{-# INLINE diff #-}
-- Fold the adjoints in @partial@ into the gradient vector @jacob@.
--
-- For every entry of @t@ shaped @(j, (0, 1, ix, _))@, the values of
-- @partial@ in row @j2ix ! j@ over columns @lo .. hi - 1@ are added, left to
-- right, to the current value of @jacob ! ix@, and the sum is written back.
-- This presumably marks a parameter leaf whose parameter index is @ix@;
-- confirm against the producer of @t@. All other entries are ignored.
combine :: (Int, Int) -> MArray (PrimState IO) S Ix2 Double -> MArray (PrimState IO) S Ix1 Double -> IO ()
combine (lo, hi) partial jacob = myForM_ t accumulate
  where
    accumulate (node, (0, 1, pIx, _)) = do
      let row = j2ix IntMap.! node
          -- one fold step: add the adjoint stored in column @col@
          addCol acc col = (+ acc) <$> UMA.unsafeRead partial (row :. col)
      start <- UMA.unsafeRead jacob pIx
      total <- foldM addCol start [lo .. hi - 1]
      UMA.unsafeWrite jacob pIx total
    accumulate _ = pure ()
{-# INLINE combine #-}
-- | Forward-mode Jacobian of an expression with respect to its parameters.
--
-- The tree is folded bottom-up. Each node yields two things:
--
--   * its value, evaluated over the rows of @xss@;
--   * a list of partial derivatives, one entry per 'Param' leaf in its
--     subtree, ordered left to right (every binary node appends the left
--     operand's entries before the right operand's).
--
-- The result is the root's list, with each entry computed to a storable
-- vector. A parameter that occurs twice gets two separate entries.
-- NOTE(review): callers are presumably expected to pass trees in which each
-- parameter occurs once (hence "Unique") -- confirm.
forwardModeUniqueJac :: SRMatrix -> PVector -> Fix SRTree -> [PVector]
forwardModeUniqueJac xss theta =
  snd . second (map (M.computeAs M.S) . DL.toList) . cata alg
  where
    -- A column of ones, sized like the other per-row vectors.
    one :: Array D Ix1 Double
    one = replicateAs xss 1

    alg :: SRTree (Array D Ix1 Double, DL.DList (Array D Ix1 Double))
        -> (Array D Ix1 Double, DL.DList (Array D Ix1 Double))
    alg (Var ix)   = (xss <! ix, DL.empty)
    alg (Param ix) = (replicateAs xss (theta ! ix), DL.singleton one)
    alg (Const c)  = (replicateAs xss c, DL.empty)
    alg (Uni f (v, gs)) =
      -- chain rule: every inner partial is scaled by f'(v)
      let dv = derivative f v
       in (evalFun f v, DL.map (* dv) gs)
    alg (Bin Add (v1, l) (v2, r)) = (v1 + v2, DL.append l r)
    alg (Bin Sub (v1, l) (v2, r)) = (v1 - v2, DL.append l (DL.map negate r))
    alg (Bin Mul (v1, l) (v2, r)) =
      (v1 * v2, DL.append (DL.map (* v2) l) (DL.map (* v1) r))
    alg (Bin Div (v1, l) (v2, r)) =
      let dv = negate v1 / (v2 * v2)
       in (v1 / v2, DL.append (DL.map (/ v2) l) (DL.map (* dv) r))
    alg (Bin Power (v1, l) (v2, r)) =
      -- d(x^y) = x^(y-1) * (y * dx + x * log x * dy)
      let base = v1 ** (v2 - one)
          dl   = DL.map (* v2) l
          dr   = DL.map (* (v1 * log v1)) r
       in (v1 ** v2, DL.map (* base) (DL.append dl dr))
    alg (Bin PowerAbs (v1, l) (v2, r)) =
      -- d(|x|^y) = |x|^y * (y / x * dx + log|x| * dy)
      let p  = abs v1 ** v2
          dl = DL.map (* (v2 / v1)) l
          dr = DL.map (* log (abs v1)) r
       -- left operand's partials first, as in every other case
       in (p, DL.map (* p) (DL.append dl dr))
    alg (Bin AQ (v1, l) (v2, r)) =
      -- aq(x, y) = x / sqrt(1 + y^2)
      --   d/dx = (1 + y^2) / (1 + y^2)^1.5
      --   d/dy = -x * y   / (1 + y^2)^1.5
      let q   = one + v2 * v2
          q32 = M.map (** 1.5) q
          dl  = DL.map (* q) l
          dr  = DL.map (* negate (v1 * v2)) r
       in (v1 / sqrt q, DL.map (/ q32) (DL.append dl dr))