{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_HADDOCK not-home #-}
module Numeric.Backprop.Internal (
BVar (..),
W (..),
backpropWithN,
evalBPN,
constVar,
liftOp,
liftOp1,
liftOp2,
liftOp3,
viewVar,
setVar,
sequenceVar,
collectVar,
previewVar,
toListOfVar,
coerceVar,
ZeroFunc (..),
zfNum,
zeroFunc,
AddFunc (..),
afNum,
addFunc,
OneFunc (..),
ofNum,
oneFunc,
debugSTN,
debugIR,
TapeNode (..),
SomeTapeNode (..),
BRef (..),
Runner (..),
InpRef (..),
initWengert,
insertNode,
bvConst,
forceBVar,
forceInpRef,
forceSomeTapeNode,
forceTapeNode,
fillWengert,
bumpMaybe,
initRunner,
gradRunner,
) where
import Control.DeepSeq
import Control.Exception
import Control.Monad
import Control.Monad.ST
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Coerce
import Data.Foldable
import Data.Function
import Data.Functor.Identity
import Data.IORef
import Data.Kind
import Data.Maybe
import Data.Monoid hiding (Any (..))
import Data.Proxy
import Data.Reflection
import Data.Type.Util
import Data.Typeable
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
import Data.Vinyl.Core
import qualified Data.Vinyl.Recursive as VR
import qualified Data.Vinyl.XRec as X
import GHC.Exts (Any)
import GHC.Generics as G
import Lens.Micro
import Lens.Micro.Extras
import Numeric.Backprop.Class
import Numeric.Backprop.Op
import System.IO.Unsafe
import Unsafe.Coerce
-- | A function that "zeroes out" a value of type @a@, wrapped so it can be
-- passed explicitly rather than through a typeclass.
--
-- For 'Num' instances, 'zfNum' provides @'const' 0@.  In this module it is
-- used to build the structure a partial gradient is written into
-- ('viewVar') and to replace the gradient of an overwritten field
-- ('setVar').
newtype ZeroFunc a = ZF {runZF :: a -> a}
-- | A binary function that adds two values of type @a@, wrapped so it can
-- be passed explicitly rather than through a typeclass.
--
-- For 'Num' instances, 'afNum' provides @('+')@.  The lifting functions in
-- this module store it as the '_irAdd' field of each 'InpRef' they build.
newtype AddFunc a = AF {runAF :: a -> a -> a}
-- | A function that maps a value of type @a@ to a "one" of the same type,
-- wrapped so it can be passed explicitly rather than through a typeclass.
--
-- For 'Num' instances, 'ofNum' provides @'const' 1@.  It is not used in
-- this part of the module; presumably it seeds the output gradient when
-- backpropagating (NOTE(review): confirm against 'backpropWithN').
newtype OneFunc a = OF {runOF :: a -> a}
-- | A 'ZeroFunc' for any 'Num' instance: ignores its argument and returns
-- @0@.
zfNum :: Num a => ZeroFunc a
zfNum = ZF (const 0)
{-# INLINE zfNum #-}
-- | An 'AddFunc' for any 'Num' instance: ordinary addition, @('+')@.
afNum :: Num a => AddFunc a
afNum = AF (+)
{-# INLINE afNum #-}
-- | A 'OneFunc' for any 'Num' instance: ignores its argument and returns
-- @1@.
ofNum :: Num a => OneFunc a
ofNum = OF (const 1)
{-# INLINE ofNum #-}
-- | A value of type @a@ together with a reference to where it came from in
-- the graph associated with the (phantom) tag @s@.
--
-- Both fields are strict, so building a 'BVar' evaluates the value to WHNF.
data BVar s a = BV
  { _bvRef :: !(BRef s)
  -- ^ origin of the value: an input, a tape node, or a constant
  , _bvVal :: !a
  -- ^ the value itself
  }

deriving instance Typeable (BVar s a)

instance X.IsoHKD (BVar s) a
-- | Where a 'BVar' comes from.  The @s@ parameter is phantom; it only ties
-- the reference to the 'BVar's of the same computation.
data BRef (s :: Type)
  = BRInp !Int
  -- ^ an input; presumably the 'Int' is the input's position
  -- (not constructed in this part of the module — TODO confirm)
  | BRIx !Int
  -- ^ the tape node with this index, as assigned by 'insertNode'
  | BRC
  -- ^ a constant, as produced by 'constVar'
  deriving (Generic, Show)

instance NFData (BRef s)
-- | Fully evaluates both the reference and the wrapped value.
instance NFData a => NFData (BVar s a) where
  rnf (BV r v) = rnf r `seq` rnf v
-- | The wrapped value if the 'BVar' is a constant ('BRC'), otherwise
-- 'Nothing' (inputs and tape nodes).
bvConst :: BVar s a -> Maybe a
bvConst (BV BRC !x) = Just x
bvConst _ = Nothing
{-# INLINE bvConst #-}
-- | Forces a 'BVar': its reference to normal form and its value to WHNF.
forceBVar :: BVar s a -> ()
forceBVar (BV r !_) = rnf r
{-# INLINE forceBVar #-}
-- | One input edge of a 'TapeNode'.
--
-- The index @a@ is the type of gradient the node produces for this input.
-- The variable's tag @s@ and value type @b@ are existential.  The lifting
-- functions in this module use @b ~ a@ with 'id' as the embedding.  The
-- exception is 'viewVar', which adds into and embeds into the whole
-- structure @b@ through a lens.
data InpRef :: Type -> Type where
  IR ::
    { _irIx :: !(BVar s b)
    -- ^ the variable this edge points back to
    , _irAdd :: !(a -> b -> b)
    -- ^ adds a gradient of type @a@ into a value of type @b@
    , _irEmbed :: !(a -> b)
    -- ^ turns a gradient of type @a@ into a value of type @b@
    } ->
    InpRef a
-- | Forces an 'InpRef': the referenced 'BVar' via 'forceBVar', and both
-- stored functions to WHNF.
forceInpRef :: InpRef a -> ()
forceInpRef (IR v !_ !_) = forceBVar v
{-# INLINE forceInpRef #-}
-- | Debug rendering of an 'InpRef': the 'show'n 'BRef' of the variable it
-- points back to.
debugIR :: InpRef a -> String
debugIR IR{_irIx = v} = show (_bvRef v)
-- | One recorded operation on the tape.  The output type is @a@; the input
-- types @as@ are existential.
data TapeNode :: Type -> Type where
  TN ::
    { _tnInputs :: !(Rec InpRef as)
    -- ^ one edge per input of the operation
    , _tnGrad :: !(a -> Rec Identity as)
    -- ^ maps a gradient of the output to a gradient for each input
    } ->
    TapeNode a
-- | Forces a 'TapeNode': every input edge via 'forceInpRef', and the
-- gradient function to WHNF.
forceTapeNode :: TapeNode a -> ()
forceTapeNode (TN inps !_) = forceInputs inps
  where
    -- Explicit recursion rather than @VR.rfoldMap forceInpRef@: the
    -- 'Semigroup'/'Monoid' instance for @()@ ignores its operands
    -- (@_ <> _ = ()@).  Folding into @()@ would therefore never evaluate
    -- 'forceInpRef' on any element.
    forceInputs :: Rec InpRef xs -> ()
    forceInputs RNil = ()
    forceInputs (ir :& irs) = forceInpRef ir `seq` forceInputs irs
{-# INLINE forceTapeNode #-}
-- | A 'TapeNode' with its output type hidden, so nodes of different types
-- can share one list on the tape ('W').
data SomeTapeNode :: Type where
  STN ::
    { _stnNode :: !(TapeNode a)
    } ->
    SomeTapeNode
-- | Forces the wrapped node with 'forceTapeNode'.
forceSomeTapeNode :: SomeTapeNode -> ()
forceSomeTapeNode (STN n) = forceTapeNode n
-- | Debug rendering of a tape node: the 'show'n list of its inputs'
-- references, one 'debugIR' string per input.
debugSTN :: SomeTapeNode -> String
debugSTN (STN TN{_tnInputs = inps}) =
  show (VR.rfoldMap (\ir -> [debugIR ir]) inps)
-- | The tape (Wengert list) a computation records into.
--
-- The 'IORef' holds the number of nodes inserted so far together with the
-- nodes themselves, most recent first (see 'insertNode').
newtype W = W {wRef :: IORef (Int, [SomeTapeNode])}
-- | Creates an empty tape: a node count of zero and no nodes.
initWengert :: IO W
initWengert = fmap W (newIORef (0, []))
{-# INLINE initWengert #-}
-- | Records a node on the tape and returns a 'BVar' referring to it.
--
-- The node's index ('BRIx') is the number of nodes already on the tape;
-- the count is then incremented and the node is consed onto the front of
-- the list.  The update uses 'atomicModifyIORef'', and the node is forced
-- with 'forceTapeNode' before it is stored.
insertNode
  :: TapeNode a
  -> a -- ^ the value computed by the node
  -> W -- ^ the tape to record into
  -> IO (BVar s a)
insertNode tn !x !w = do
  ix <- atomicModifyIORef' (wRef w) push
  pure (BV (BRIx ix) x)
  where
    -- Returns the new (count, nodes) pair and the index given to @tn@.
    push (!n, !t) =
      let n' = n + 1
          t' = STN tn : t
       in forceTapeNode tn `seq` n' `seq` t' `seq` ((n', t'), n)
{-# INLINE insertNode #-}
-- | Wraps a value as a constant 'BVar' (reference 'BRC').  Nothing is
-- written to the tape.
constVar :: a -> BVar s a
constVar = BV BRC
{-# INLINE constVar #-}
-- | Applies an 'Op' to a record of 'BVar's and records the application on
-- the tape reflected from @s@.
--
-- @afs@ supplies one 'AddFunc' per input; each is stored in that input's
-- 'InpRef'.  If every input is a constant (see 'bvConst'), the 'Op' is
-- just evaluated and the result returned as a constant, and nothing is
-- written to the tape.
liftOp_ ::
  forall s as b.
  Reifies s W =>
  Rec AddFunc as ->
  Op as b ->
  Rec (BVar s) as ->
  IO (BVar s b)
liftOp_ afs o !vs =
  case rtraverse (fmap Identity . bvConst) vs of
    Just xs -> return (constVar (evalOp o xs))
    Nothing -> insertNode tn y (reflect (Proxy @s))
  where
    (y, g) = runOpWith o (VR.rmap (Identity . _bvVal) vs)
    tn =
      TN
        { _tnInputs = VR.rzipWith go afs vs
        , _tnGrad = g
        }
    -- Builds one input edge, forcing the referenced 'BVar' first.
    go :: forall a. AddFunc a -> BVar s a -> InpRef a
    go af !v = forceBVar v `seq` IR v (runAF af) id
    {-# INLINE go #-}
{-# INLINE liftOp_ #-}
-- | Pure interface to 'liftOp_'.
--
-- 'unsafePerformIO' makes tape recording look pure.  The only effect
-- performed by 'liftOp_' is inserting a node (via 'insertNode') into the
-- tape reflected from @s@.
liftOp ::
  forall as b s.
  Reifies s W =>
  Rec AddFunc as ->
  Op as b ->
  Rec (BVar s) as ->
  BVar s b
liftOp afs o !vs = unsafePerformIO (liftOp_ afs o vs)
{-# INLINE liftOp #-}
-- | Single-input specialization of 'liftOp_'.
--
-- A constant input is evaluated directly into a constant result.
-- Otherwise a node is recorded on the tape reflected from @s@, with the
-- 'AddFunc' stored in the input's 'InpRef'.
liftOp1_ ::
  forall a b s.
  Reifies s W =>
  AddFunc a ->
  Op '[a] b ->
  BVar s a ->
  IO (BVar s b)
liftOp1_ _ o (bvConst -> Just x) =
  return (constVar (evalOp o (Identity x :& RNil)))
liftOp1_ af o v = forceBVar v `seq` insertNode tn y (reflect (Proxy @s))
  where
    (y, g) = runOpWith o (Identity (_bvVal v) :& RNil)
    tn =
      TN
        { _tnInputs = IR v (runAF af) id :& RNil
        , _tnGrad = g
        }
{-# INLINE liftOp1_ #-}
-- | Pure interface to 'liftOp1_'.  'unsafePerformIO' hides the insertion of
-- a node into the tape reflected from @s@.
liftOp1 ::
  forall a b s.
  Reifies s W =>
  AddFunc a ->
  Op '[a] b ->
  BVar s a ->
  BVar s b
liftOp1 af o !v = unsafePerformIO (liftOp1_ af o v)
{-# INLINE liftOp1 #-}
-- | Two-input specialization of 'liftOp_'.
--
-- If both inputs are constants, the 'Op' is evaluated directly into a
-- constant result.  Otherwise both inputs are forced and a node is
-- recorded on the tape reflected from @s@.
liftOp2_ ::
  forall a b c s.
  Reifies s W =>
  AddFunc a ->
  AddFunc b ->
  Op '[a, b] c ->
  BVar s a ->
  BVar s b ->
  IO (BVar s c)
liftOp2_ _ _ o (bvConst -> Just x) (bvConst -> Just y) =
  return (constVar (evalOp o (Identity x :& Identity y :& RNil)))
liftOp2_ afa afb o v u =
  forceBVar v
    `seq` forceBVar u
    `seq` insertNode tn y (reflect (Proxy @s))
  where
    (y, g) = runOpWith o (Identity (_bvVal v) :& Identity (_bvVal u) :& RNil)
    tn =
      TN
        { _tnInputs =
            IR v (runAF afa) id
              :& IR u (runAF afb) id
              :& RNil
        , _tnGrad = g
        }
{-# INLINE liftOp2_ #-}
-- | Pure interface to 'liftOp2_'.  'unsafePerformIO' hides the insertion of
-- a node into the tape reflected from @s@.
liftOp2 ::
  forall a b c s.
  Reifies s W =>
  AddFunc a ->
  AddFunc b ->
  Op '[a, b] c ->
  BVar s a ->
  BVar s b ->
  BVar s c
liftOp2 afa afb o !v !u = unsafePerformIO (liftOp2_ afa afb o v u)
{-# INLINE liftOp2 #-}
-- | Three-input specialization of 'liftOp_'.
--
-- If all three inputs are constants, the 'Op' is evaluated directly into a
-- constant result.  Otherwise all inputs are forced and a node is recorded
-- on the tape reflected from @s@.
liftOp3_ ::
  forall a b c d s.
  Reifies s W =>
  AddFunc a ->
  AddFunc b ->
  AddFunc c ->
  Op '[a, b, c] d ->
  BVar s a ->
  BVar s b ->
  BVar s c ->
  IO (BVar s d)
liftOp3_ _ _ _ o (bvConst -> Just x) (bvConst -> Just y) (bvConst -> Just z) =
  return (constVar (evalOp o (Identity x :& Identity y :& Identity z :& RNil)))
liftOp3_ afa afb afc o v u w =
  forceBVar v
    `seq` forceBVar u
    `seq` forceBVar w
    `seq` insertNode tn y (reflect (Proxy @s))
  where
    (y, g) =
      runOpWith o $
        Identity (_bvVal v)
          :& Identity (_bvVal u)
          :& Identity (_bvVal w)
          :& RNil
    tn =
      TN
        { _tnInputs =
            IR v (runAF afa) id
              :& IR u (runAF afb) id
              :& IR w (runAF afc) id
              :& RNil
        , _tnGrad = g
        }
{-# INLINE liftOp3_ #-}
-- | Pure interface to 'liftOp3_'.  'unsafePerformIO' hides the insertion of
-- a node into the tape reflected from @s@.
liftOp3 ::
  forall a b c d s.
  Reifies s W =>
  AddFunc a ->
  AddFunc b ->
  AddFunc c ->
  Op '[a, b, c] d ->
  BVar s a ->
  BVar s b ->
  BVar s c ->
  BVar s d
liftOp3 afa afb afc o !v !u !w = unsafePerformIO (liftOp3_ afa afb afc o v u w)
{-# INLINE liftOp3 #-}
-- | Focuses a lens on a 'BVar', recording the projection on the tape
-- reflected from @s@.
--
-- The input edge adds a field gradient into a structure gradient with
-- @'over' l . 'runAF' af@.  It embeds a field gradient by 'set'ting it into
-- the input value zeroed by @z@.
--
-- Like the @liftOp*_@ functions, a constant input gives a constant result
-- and records nothing: a node whose only input is a constant has nowhere
-- to send a gradient.
viewVar_ ::
  forall a b s.
  Reifies s W =>
  AddFunc a ->
  ZeroFunc b ->
  Lens' b a ->
  BVar s b ->
  IO (BVar s a)
viewVar_ _ _ l (bvConst -> Just x) = return (constVar (x ^. l))
viewVar_ af z l v = forceBVar v `seq` insertNode tn y (reflect (Proxy @s))
  where
    x :: b
    x = _bvVal v
    y :: a
    y = x ^. l
    tn :: TapeNode a
    tn =
      TN
        { _tnInputs =
            IR v (over l . runAF af) (\g -> set l g (runZF z x))
              :& RNil
        , _tnGrad = (:& RNil) . Identity
        }
{-# INLINE viewVar_ #-}
-- | Pure interface to 'viewVar_'.  'unsafePerformIO' hides the insertion of
-- a node into the tape reflected from @s@.
viewVar ::
  forall a b s.
  Reifies s W =>
  AddFunc a ->
  ZeroFunc b ->
  Lens' b a ->
  BVar s b ->
  BVar s a
viewVar af z l !v = unsafePerformIO (viewVar_ af z l v)
{-# INLINE viewVar #-}
-- | Overwrites the field focused by a lens in @v@ with the value of @w@,
-- recording the update on the tape reflected from @s@.
--
-- Backward pass: the output gradient is split by the lens.  The focused
-- part goes to @w@, and the rest goes to @v@ with the focused part
-- replaced via @za@.
--
-- Like 'liftOp2_', two constant inputs give a constant result and record
-- nothing.
setVar_ ::
  forall a b s.
  Reifies s W =>
  AddFunc a ->
  AddFunc b ->
  ZeroFunc a ->
  Lens' b a ->
  BVar s a ->
  BVar s b ->
  IO (BVar s b)
setVar_ _ _ _ l (bvConst -> Just xw) (bvConst -> Just xv) =
  return (constVar (xv & l .~ xw))
setVar_ afa afb za l w v =
  forceBVar v
    `seq` forceBVar w
    `seq` insertNode tn y (reflect (Proxy @s))
  where
    y :: b
    y = _bvVal v & l .~ _bvVal w
    tn :: TapeNode b
    tn =
      TN
        { _tnInputs =
            IR w (runAF afa) id
              :& IR v (runAF afb) id
              :& RNil
        , _tnGrad = \d ->
            -- The lens at the @(,) a@ functor extracts the focused gradient
            -- and rebuilds @d@ with that part passed through @za@.
            let (dw, dv) = l (\x -> (x, runZF za x)) d
             in Identity dw :& Identity dv :& RNil
        }
{-# INLINE setVar_ #-}
-- | Set the part of a 'BVar' focused by a lens to the value of another
-- 'BVar', recording the update on the tape identified by @s@.
--
-- Pure wrapper around 'setVar_'. Both 'BVar' arguments are forced to WHNF
-- (bang patterns) before the tape is touched.
--
--   * @afa@ - accumulates gradients flowing back into the new part @w@
--   * @afb@ - accumulates gradients flowing back into the original whole @v@
--   * @za@  - zero for the part; the focused slot of the whole's gradient
--             is replaced by this zero, because that slot was overwritten
--   * @l@   - lens from the whole @b@ to the part @a@
--
-- NOTE(review): 'unsafePerformIO' lets the tape insertion run from pure
-- code; presumably safe because the only effect is recording a node on the
-- tape reached through @Reifies s W@ - confirm against 'insertNode'.
setVar
    :: forall a b s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> ZeroFunc a
    -> Lens' b a
    -> BVar s a
    -> BVar s b
    -> BVar s b
setVar afa afb za l !w !v = unsafePerformIO $ setVar_ afa afb za l w v
{-# INLINE setVar #-}
-- | Split a 'BVar' holding a 'Traversable' container into a container of
-- 'BVar's, one per element. Each element is recorded as its own tape node
-- (see 'traverseVar'').
--
-- The zero for the whole container is built by applying the element zero
-- @z@ to every element with 'fmap'. Elements are addressed in 'traverse'
-- order. Element gradients are accumulated with @af@.
--
-- NOTE(review): 'unsafePerformIO' is used so that the tape insertions can
-- run from pure code; the input is forced to WHNF first (bang pattern).
sequenceVar
    :: forall t a s. (Reifies s W, Traversable t)
    => AddFunc a
    -> ZeroFunc a
    -> BVar s (t a)
    -> t (BVar s a)
sequenceVar af z !v = unsafePerformIO $
    traverseVar' af (ZF (fmap (runZF z))) id traverse v
{-# INLINE sequenceVar #-}
-- | Gather a container of 'BVar's into a single 'BVar' of the container.
-- This records one tape node whose inputs are all of the elements.
--
-- On the backward pass, the gradient container is flattened with 'toList'
-- and paired positionally with the inputs. An input with no matching
-- gradient element (presumably when the flattened gradient is shorter than
-- the input list) receives @z@ applied to its own value. Input gradients
-- are accumulated with @af@.
--
-- Every element is forced with 'forceBVar' before the node is inserted.
collectVar_
    :: forall t a s. (Reifies s W, Foldable t, Functor t)
    => AddFunc a
    -> ZeroFunc a
    -> t (BVar s a)
    -> IO (BVar s (t a))
collectVar_ af z !vs = withVec (toList vs) $ \(vVec :: VecT n (BVar s) a) -> do
    let tn :: TapeNode (t a)
        tn = TN
          { _tnInputs = vecToRec (vmap (\v -> IR v (runAF af) id) vVec)
          , _tnGrad   = vecToRec
                      . zipVecList
                          (\v -> Identity . fromMaybe (runZF z (_bvVal v)))
                          vVec
                      . toList
          }
    traverse_ (evaluate . forceBVar) vs
    insertNode tn (_bvVal <$> vs) (reflect (Proxy @s))
{-# INLINE collectVar_ #-}
-- | Pure wrapper around 'collectVar_'. It gathers a container of 'BVar's
-- into one 'BVar' of the container, recording a single tape node.
--
-- The container is forced to WHNF (bang pattern) before the tape is
-- touched.
--
-- NOTE(review): 'unsafePerformIO' is used so that the tape insertion can
-- run from pure code.
collectVar
    :: forall t a s. (Reifies s W, Foldable t, Functor t)
    => AddFunc a
    -> ZeroFunc a
    -> t (BVar s a)
    -> BVar s (t a)
collectVar af z !vs = unsafePerformIO $ collectVar_ af z vs
{-# INLINE collectVar #-}
-- | Shared worker for 'sequenceVar', 'previewVar' and 'toListOfVar'. It
-- splits a 'BVar' into one 'BVar' per element produced by @f@ and inserts
-- a separate tape node for each one.
--
-- The node for the @i@-th element routes its gradient back to the whole
-- through @ixt t i@:
--
--   * if the whole has no gradient yet, the incoming gradient is written
--     into slot @i@ of a zero whole (built with @z@);
--   * otherwise it is combined into slot @i@ of the existing gradient
--     with @af@.
--
-- NOTE(review): this relies on @f@ listing the targets of @t@ in the same
-- order that @t@ visits them, so that index @i@ from 'itraverse' names the
-- same focus as @ixt t i@. 'sequenceVar' passes @id@ / 'traverse',
-- 'previewVar' passes @preview t@ / @t@, and 'toListOfVar' passes
-- @toListOf t@ / @t@.
traverseVar'
    :: forall b a f s. (Reifies s W, Traversable f)
    => AddFunc a
    -> ZeroFunc b
    -> (b -> f a)
    -> Traversal' b a
    -> BVar s b
    -> IO (f (BVar s a))
traverseVar' af z f t v = forceBVar v `seq` itraverse go (f x)
  where
    x :: b
    x = _bvVal v

    -- Insert the tape node for the element at position i with value y.
    go :: Int -> a -> IO (BVar s a)
    go i y = insertNode tn y (reflect (Proxy @s))
      where
        tn :: TapeNode a
        tn = TN
          { _tnInputs = IR v (over (ixt t i) . runAF af)
                             (\g -> set (ixt t i) g (runZF z x))
                     :& RNil
          , _tnGrad   = (:& RNil) . Identity
          }
    {-# INLINE go #-}
{-# INLINE traverseVar' #-}
-- | Return a 'BVar' for the first target of a traversal, or 'Nothing' if
-- the traversal has no target in the current value.
--
-- The result is recorded on the tape via 'traverseVar''. Gradients of the
-- target are accumulated with @af@, and @z@ supplies the zero for the
-- whole. The input is forced to WHNF (bang pattern) first.
--
-- NOTE(review): 'unsafePerformIO' is used so that the tape insertion can
-- run from pure code.
previewVar
    :: forall b a s. Reifies s W
    => AddFunc a
    -> ZeroFunc b
    -> Traversal' b a
    -> BVar s b
    -> Maybe (BVar s a)
previewVar af z t !v = unsafePerformIO $
    traverseVar' af z (preview t) t v
{-# INLINE previewVar #-}
-- | Return one 'BVar' for every target of a traversal, in traversal order.
--
-- Each target is recorded as its own tape node via 'traverseVar''.
-- Gradients of a target are accumulated with @af@, and @z@ supplies the
-- zero for the whole. The input is forced to WHNF (bang pattern) first.
--
-- NOTE(review): 'unsafePerformIO' is used so that the tape insertions can
-- run from pure code.
toListOfVar
    :: forall b a s. Reifies s W
    => AddFunc a
    -> ZeroFunc b
    -> Traversal' b a
    -> BVar s b
    -> [BVar s a]
toListOfVar af z t !v = unsafePerformIO $
    traverseVar' af z (toListOf t) t v
{-# INLINE toListOfVar #-}
-- | Reinterpret a 'BVar' at a representationally equal type.
--
-- No tape node is inserted. The result keeps the original 'BRef', so any
-- gradient flowing into it is accumulated into the original variable's
-- slot during the backward pass. The input is forced with 'forceBVar'
-- first.
coerceVar
    :: Coercible a b
    => BVar s a
    -> BVar s b
coerceVar v@(BV r x) = forceBVar v `seq` BV r (coerce x)
{-# INLINE coerceVar #-}
-- | Mutable state for one backward pass.
--
-- Slots hold type-erased gradients ('unsafeCoerce'd to 'Any'). 'Nothing'
-- means that no gradient has been accumulated in that slot yet.
data Runner s = R
    { _rDelta  :: !(MV.MVector s (Maybe Any))
      -- ^ one slot per tape node, indexed by the node's position
    , _rInputs :: !(MV.MVector s (Maybe Any))
      -- ^ one slot per input of the function being differentiated
    }
-- | Allocate the mutable state for a backward pass.
--
-- The first argument is the tape: its length @n@ and its nodes. The head
-- of the node list is stored at index @n - 1@, the next node at @n - 2@,
-- and so on. Every node's delta slot starts as 'Nothing'.
--
-- The second argument is the number of inputs @nx@ and their initial
-- accumulator values. These are written in list order, presumably to
-- indices @0 .. nx - 1@.
--
-- NOTE(review): delta slots are initialised only for the nodes actually
-- present in the list. This assumes @length stns == n@; any other slot
-- would be left as an uninitialised 'MV.new' element.
initRunner
    :: (Int, [SomeTapeNode])
    -> (Int, [Maybe Any])
    -> ST s (Runner s)
initRunner (n, stns) (nx, xs) = do
    delts <- MV.new n
    -- The type application only records which type the slot will later
    -- hold; at runtime the value written is a plain 'Nothing'.
    for_ (zip [n - 1, n - 2 ..] stns) $ \(i, STN (TN{} :: TapeNode c)) ->
      MV.write delts i $ unsafeCoerce (Nothing @c)
    inps <- MV.new nx
    itraverse_ (MV.write inps) xs
    return $ R delts inps
{-# INLINE initRunner #-}
-- | Run the backward pass over a recorded tape.
--
-- The result gradient @o@ seeds the delta of the node at index @n - 1@.
-- Nodes are then visited from index @n - 1@ down to @0@, using the same
-- head-first numbering as 'initRunner'. Each node that has received a
-- delta computes its inputs' gradients with '_tnGrad'. Each gradient is
-- then added into the slot named by that input's 'BRef':
--
--   * '_rInputs' for function inputs ('BRInp');
--   * '_rDelta' for other tape nodes ('BRIx');
--   * nowhere for constants ('BRC').
--
-- Gradients are stored as 'Maybe Any'. Slot @i@ of '_rDelta' is read back
-- at node @i@'s own type, which is presumably what keeps the
-- 'unsafeCoerce's sound.
--
-- NOTE(review): seeding index @n - 1@ assumes the result is the most
-- recently inserted tape node. 'fillWengert' only distinguishes "result
-- is an input" from everything else. A result that is a constant, or an
-- older node forced after newer ones, could therefore have @o@
-- attributed to the wrong node - confirm that this cannot happen.
gradRunner
    :: forall b s. ()
    => b                        -- ^ gradient of the final result
    -> Runner s
    -> (Int, [SomeTapeNode])
    -> ST s ()
gradRunner o R{..} (n, stns) = do
    when (n > 0) $
      MV.write _rDelta (n - 1) (unsafeCoerce (Just o))
    zipWithM_ go [n - 1, n - 2 ..] stns
  where
    -- Back-propagate through one node, if any gradient has reached it.
    go :: Int -> SomeTapeNode -> ST s ()
    go i (STN (TN{..} :: TapeNode c)) = do
      delt <- MV.read _rDelta i
      forM_ delt $ \d -> do
        let gs = _tnGrad (unsafeCoerce d)
        rzipWithM_ propagate _tnInputs gs
    {-# INLINE go #-}

    -- Accumulate one input's gradient into the slot its reference names.
    propagate :: forall x. InpRef x -> Identity x -> ST s ()
    propagate (IR v (+*) e) (Identity d) = case _bvRef v of
      BRInp i -> flip (MV.modify _rInputs) i $
        unsafeCoerce . bumpMaybe d (+*) e . unsafeCoerce
      BRIx i  -> flip (MV.modify _rDelta) i $
        unsafeCoerce . bumpMaybe d (+*) e . unsafeCoerce
      BRC     -> return ()
    {-# INLINE propagate #-}
{-# INLINE gradRunner #-}
-- | Add a contribution into an optional accumulator.
--
-- If nothing has been accumulated yet, the contribution is embedded with
-- the third argument. Otherwise it is combined with the existing total
-- using the second argument. The result is always 'Just'.
bumpMaybe
    :: a                -- ^ contribution to add
    -> (a -> b -> b)    -- ^ combine a contribution with an existing total
    -> (a -> b)         -- ^ embed a contribution when there is no total yet
    -> Maybe b          -- ^ current total, if any
    -> Maybe b
bumpMaybe x (+*) e mb = case mb of
    Nothing -> Just (e x)
    Just y  -> Just (x +* y)
{-# INLINE bumpMaybe #-}
-- | Force the output of 'fillWengert' and return it unchanged.
--
-- For @Left l@, the payload @l@ is forced. For @Right (r, nodes)@, @r@ is
-- forced, and every tape node is forced through 'forceSomeTapeNode'.
--
-- A strict left fold is used on purpose. @foldMap forceSomeTapeNode@
-- would combine the results in the @()@ monoid, whose '<>' and 'mconcat'
-- ignore their arguments, so no tape node would actually be forced.
seqEither :: Either a (b, [SomeTapeNode]) -> Either a (b, [SomeTapeNode])
seqEither e = case e of
    Left l           -> l `seq` e
    Right (r, nodes) -> r `seq` forceNodes nodes `seq` e
  where
    -- foldl' forces each step's accumulator, i.e. each node's forcing.
    forceNodes :: [SomeTapeNode] -> ()
    forceNodes = foldl' (\_ stn -> forceSomeTapeNode stn) ()
{-# INLINE seqEither #-}
-- | Run a function on 'BVar's and return its result, together with a
-- function that maps a gradient of the result to gradients of every input.
--
-- The forward pass ('fillWengert') runs once. The strict where-binding
-- below forces it as soon as the returned pair is demanded. The recorded
-- tape then takes one of two forms:
--
--   * @Left i@ - the result /is/ input @i@. The gradient is the supplied
--     output gradient in slot @i@ and zeros everywhere else
--     ('setInput').
--   * @Right tape@ - an ordinary tape, replayed backwards by 'gradRunner'.
--
-- Inputs that receive no gradient get their zero: the matching @zfs@
-- entry applied to the input value.
--
-- NOTE(review): 'unsafePerformIO' runs the tape-recording forward pass
-- from pure code.
backpropWithN
    :: forall as b. ()
    => Rec ZeroFunc as
    -> (forall s. Reifies s W => Rec (BVar s) as -> BVar s b)
    -> Rec Identity as
    -> (b, b -> Rec Identity as)
backpropWithN zfs f !xs = (y, g')
  where
    !(seqEither -> (!tp0), !y) = unsafePerformIO $ fillWengert f xs

    g' :: b -> Rec Identity as
    g' = case tp0 of
      Left i   -> setInput i
      Right tp -> g tp
    {-# INLINE g' #-}

    -- Full backward pass over a recorded tape.
    g :: (Int, [SomeTapeNode]) -> b -> Rec Identity as
    g tp o = runST $ do
        r <- initRunner tp
           . bimap getSum (`appEndo` [])
           . VR.rfoldMap go
           $ xs
        gradRunner o r tp
        delts <- toList <$> V.freeze (_rInputs r)
        -- fillRec yields Nothing presumably only when the number of
        -- accumulated gradients does not match the number of inputs.
        return . fromMaybe (internalError "backpropN") $
          fillRec (\z -> maybe z (Identity . unsafeCoerce))
                  (VR.rzipWith (fmap . runZF) zfs xs)
                  delts
      where
        -- Count the inputs and create one empty accumulator per input.
        go :: forall x. Identity x -> (Sum Int, Endo [Maybe Any])
        go _ = (Sum 1, Endo (unsafeCoerce (Nothing @x) :))
        {-# INLINE go #-}

    -- Gradient when the result is exactly input i: the output gradient
    -- goes to slot i, and every other input gets its zero.
    setInput :: Int -> b -> Rec Identity as
    setInput !i !x = go zfs xs 0
      where
        go :: Rec ZeroFunc bs -> Rec Identity bs -> Int -> Rec Identity bs
        go = \case
          RNil -> \_ _ -> RNil
          z :& zs -> \case
            q :& qs -> \(!j) ->
              if j == i
                then Identity (unsafeCoerce x) :& VR.rzipWith coerce zs qs
                else coerce z q :& go zs qs (j + 1)
    {-# INLINE setInput #-}
{-# INLINE backpropWithN #-}
-- | Run a function on 'BVar's forward only and return its result.
--
-- The forward pass still records a tape via 'fillWengert', but the tape
-- is discarded ('snd') and no gradients are computed.
--
-- NOTE(review): 'unsafePerformIO' runs the tape-recording forward pass
-- from pure code.
evalBPN
    :: forall as b. ()
    => (forall s. Reifies s W => Rec (BVar s) as -> BVar s b)
    -> Rec Identity as
    -> b
evalBPN f = snd . unsafePerformIO . fillWengert f
{-# INLINE evalBPN #-}
-- | Run a backpropagatable function on concrete inputs and capture the
-- resulting Wengert tape.
--
-- The inputs are wrapped as input variables numbered @0, 1, 2, ...@ (see
-- @inpRec@) and the function is evaluated against a 'W' obtained from
-- 'initWengert' (presumably a fresh, empty tape — confirm in 'initWengert').
--
-- The result pairs the output's value ('_bvVal') with:
--
-- * @'Left' i@ when the output's reference is @'BRInp' i@, i.e. the output
--   is directly the @i@-th input; the tape is not read in that case.
-- * @'Right' tape@ otherwise, where @tape@ is the contents of @'wRef' w@
--   read after the output has been forced.
fillWengert ::
  forall as b.
  () =>
  (forall s. Reifies s W => Rec (BVar s) as -> BVar s b) ->
  Rec Identity as ->
  IO (Either Int (Int, [SomeTapeNode]), b)
fillWengert f xs = do
  w <- initWengert
  (inputIx, outVal) <- reify w $ \(Proxy :: Proxy s) -> do
    let out :: BVar s b
        out = f (inpRec @s)
    -- Force the output before the tape is inspected below.
    -- NOTE(review): presumably this is what makes all pending tape
    -- insertions happen first; confirm against 'forceBVar' / 'insertNode'.
    evaluate (forceBVar out)
    let ix = case _bvRef out of
          BRInp j -> Just j
          _ -> Nothing
    pure (ix, _bvVal out)
  tape <- maybe (Right <$> readIORef (wRef w)) (pure . Left) inputIx
  pure (tape, outVal)
  where
    -- The inputs as 'BVar's, the k-th one referencing @'BRInp' k@.
    inpRec :: forall s. Rec (BVar s) as
    inpRec = evalState (rtraverse (state . mkInput . runIdentity) xs) 0
      where
        mkInput :: a -> Int -> (BVar s a, Int)
        mkInput x i = (BV (BRInp i) x, i + 1)
        {-# INLINE mkInput #-}
    {-# INLINE inpRec #-}
{-# INLINE fillWengert #-}
-- | Arithmetic on 'BVar's.
--
-- Every operator lifts the matching differentiable 'Op' ('(+.)', '(-.)',
-- '(*.)', 'negateOp', 'signumOp', 'absOp') through 'liftOp2' or 'liftOp1'.
-- The 'Num'-based 'AddFunc' 'afNum' is supplied for each input (presumably
-- used to sum gradient contributions). Integer literals are wrapped with
-- 'constVar'.
instance (Num a, Reifies s W) => Num (BVar s a) where
  -- binary operators
  (+) = liftOp2 afNum afNum (+.)
  {-# INLINE (+) #-}
  (-) = liftOp2 afNum afNum (-.)
  {-# INLINE (-) #-}
  (*) = liftOp2 afNum afNum (*.)
  {-# INLINE (*) #-}

  -- unary operators
  negate = liftOp1 afNum negateOp
  {-# INLINE negate #-}
  signum = liftOp1 afNum signumOp
  {-# INLINE signum #-}
  abs = liftOp1 afNum absOp
  {-# INLINE abs #-}

  -- literals
  fromInteger n = constVar (fromInteger n)
  {-# INLINE fromInteger #-}
-- | Division on 'BVar's.
--
-- '(/)' and 'recip' lift the differentiable ops '(/.)' and 'recipOp', with
-- 'afNum' as the 'AddFunc' for each input. Rational literals are wrapped
-- with 'constVar'.
instance (Fractional a, Reifies s W) => Fractional (BVar s a) where
  (/) = liftOp2 afNum afNum (/.)
  {-# INLINE (/) #-}
  recip = liftOp1 afNum recipOp
  {-# INLINE recip #-}
  fromRational r = constVar (fromRational r)
  {-# INLINE fromRational #-}
-- | Transcendental functions on 'BVar's.
--
-- 'pi' is a constant built with 'constVar'. Every other method lifts the
-- correspondingly named differentiable 'Op' ('expOp', '(**.)',
-- 'logBaseOp', 'sinOp', ...) through 'liftOp1' or 'liftOp2'. The 'Num'-based
-- 'AddFunc' 'afNum' is supplied for each input.
instance (Floating a, Reifies s W) => Floating (BVar s a) where
  -- constants
  pi = constVar pi
  {-# INLINE pi #-}

  -- binary functions
  (**) = liftOp2 afNum afNum (**.)
  {-# INLINE (**) #-}
  logBase = liftOp2 afNum afNum logBaseOp
  {-# INLINE logBase #-}

  -- exponentials, logarithms and roots
  exp = liftOp1 afNum expOp
  {-# INLINE exp #-}
  log = liftOp1 afNum logOp
  {-# INLINE log #-}
  sqrt = liftOp1 afNum sqrtOp
  {-# INLINE sqrt #-}

  -- trigonometric functions
  sin = liftOp1 afNum sinOp
  {-# INLINE sin #-}
  cos = liftOp1 afNum cosOp
  {-# INLINE cos #-}
  tan = liftOp1 afNum tanOp
  {-# INLINE tan #-}
  asin = liftOp1 afNum asinOp
  {-# INLINE asin #-}
  acos = liftOp1 afNum acosOp
  {-# INLINE acos #-}
  atan = liftOp1 afNum atanOp
  {-# INLINE atan #-}

  -- hyperbolic functions
  sinh = liftOp1 afNum sinhOp
  {-# INLINE sinh #-}
  cosh = liftOp1 afNum coshOp
  {-# INLINE cosh #-}
  tanh = liftOp1 afNum tanhOp
  {-# INLINE tanh #-}
  asinh = liftOp1 afNum asinhOp
  {-# INLINE asinh #-}
  acosh = liftOp1 afNum acoshOp
  {-# INLINE acosh #-}
  atanh = liftOp1 afNum atanhOp
  {-# INLINE atanh #-}
-- | Equality of 'BVar's compares only the wrapped values ('_bvVal'). The
-- variables' references are ignored. The result is a plain 'Bool', so no
-- operation is recorded.
instance Eq a => Eq (BVar s a) where
  x == y = _bvVal x == _bvVal y
  x /= y = _bvVal x /= _bvVal y
-- | Ordering of 'BVar's compares only the wrapped values ('_bvVal').
-- Results are plain 'Ordering' / 'Bool' values.
instance Ord a => Ord (BVar s a) where
  compare x y = compare (_bvVal x) (_bvVal y)
  x < y = _bvVal x < _bvVal y
  x <= y = _bvVal x <= _bvVal y
  x > y = _bvVal x > _bvVal y
  x >= y = _bvVal x >= _bvVal y
-- | Like 'traverse', but the action also receives the zero-based position
-- of each element, counted in traversal order.
--
-- The counter is threaded with a 'StateT' over @f@, which is why @f@ must
-- be a 'Monad'.
itraverse ::
  forall t a b f.
  (Traversable t, Monad f) =>
  (Int -> a -> f b) ->
  t a ->
  f (t b)
itraverse f xs = evalStateT (traverse step xs) 0
  where
    -- Run the action at the current counter value, then bump the counter.
    step :: a -> StateT Int f b
    step x = StateT $ \i -> (\y -> (y, i + 1)) <$> f i x
{-# INLINE itraverse #-}
-- | Like 'traverse_', but the action also receives the zero-based position
-- of each element (positions follow 'toList' order). Results are discarded.
--
-- Only 'Applicative' is required. The positions are attached up front by
-- zipping with @[0 ..]@, so no state is threaded between effects. Any
-- 'Monad' still qualifies.
itraverse_ ::
  forall t a b f.
  (Foldable t, Applicative f) =>
  (Int -> a -> f b) ->
  t a ->
  f ()
itraverse_ f xs = traverse_ (uncurry f) (zip [0 ..] (toList xs))
{-# INLINE itraverse_ #-}
-- | A lens onto the element at a zero-based index of a list.
--
-- Focusing an index that is not present, including any negative index,
-- walks off the end of the list and calls 'internalError' with @"ixi"@.
ixi :: Int -> Lens' [a] a
ixi n f = \case
  [] -> internalError "ixi"
  x : xs
    | n == 0 -> (: xs) <$> f x
    | otherwise -> (x :) <$> ixi (n - 1) f xs
{-# INLINE ixi #-}
-- | Turn a traversal and a zero-based index into a lens onto that target of
-- the traversal.
--
-- The steps are:
--
-- 1. Collect the targets into a list with '^..'.
-- 2. Focus the requested element of that list with 'ixi'.
-- 3. Write the possibly-updated list back into the original structure,
--    one element per target, in traversal order.
--
-- Running out of list elements while writing back calls 'internalError'
-- with @"ixt"@.
ixt :: forall b a. Traversal' b a -> Int -> Lens' b a
ixt t i f xs = refill <$> ixi i f (xs ^.. t)
  where
    -- Replace each target of @t@ in @xs@ with the next list element.
    refill :: [a] -> b
    refill = evalState (traverseOf t (\_ -> state popHead) xs)

    popHead :: [a] -> (a, [a])
    popHead (y : rest) = (y, rest)
    popHead [] = internalError "ixt"
{-# INLINE ixt #-}
-- | A 'BVar' over a 'Backprop' type is itself 'Backprop'. Each method is
-- lifted into the graph as a differentiable op, using the type's own
-- 'addFunc' for its inputs:
--
-- * 'zero' and 'one' produce @zero x@ / @one x@. Their gradient maps the
--   incoming gradient through 'zero'.
-- * 'add' produces @add x y@. Its gradient hands the incoming gradient
--   unchanged to both operands.
instance (Backprop a, Reifies s W) => Backprop (BVar s a) where
  zero = liftOp1 addFunc (op1 (\x -> (zero x, zero)))
  {-# INLINE zero #-}
  add = liftOp2 addFunc addFunc (op2 (\x y -> (add x y, \d -> (d, d))))
  {-# INLINE add #-}
  one = liftOp1 addFunc (op1 (\x -> (one x, zero)))
  {-# INLINE one #-}
-- | The 'ZeroFunc' given by a type's 'Backprop' instance, wrapping its
-- 'zero' method.
zeroFunc :: Backprop a => ZeroFunc a
zeroFunc = ZF {runZF = zero}
{-# INLINE zeroFunc #-}
-- | The 'AddFunc' given by a type's 'Backprop' instance, wrapping its
-- 'add' method.
addFunc :: Backprop a => AddFunc a
addFunc = AF add
{-# INLINE addFunc #-}
-- | The 'OneFunc' given by a type's 'Backprop' instance, wrapping its
-- 'one' method.
oneFunc :: Backprop a => OneFunc a
oneFunc = OF one
{-# INLINE oneFunc #-}
-- | Abort with an error that names the failing helper (qualified with this
-- module's name). Uses 'errorWithoutStackTrace', so no call stack is
-- attached to the message.
internalError :: String -> a
internalError m =
  errorWithoutStackTrace $
    concat
      [ "Numeric.Backprop.Internal."
      , m
      , ": unexpected shape involved in gradient computation"
      ]