module Language.Futhark.Interpreter.AD
( Op (..),
ADVariable (..),
ADValue (..),
Tape (..),
VJPValue (..),
JVPValue (..),
Counter (..),
Depth (..),
doOp,
addFor,
tapePrimal,
primitive,
varPrimal,
deriveTape,
unionWithM,
unionsWithM,
)
where
import Control.Monad (foldM, zipWithM)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Except (ExceptT, catchE, runExceptT, throwE)
import Control.Monad.Trans.State (State, get, modify, runState)
import Data.Either (isRight)
import Data.Foldable (find, foldlM)
import Data.Functor ((<&>))
import Data.Map qualified as M
import Data.Maybe (fromJust, fromMaybe)
import Data.Text qualified as T
import Futhark.AD.Derivatives (pdBinOp, pdBuiltin, pdUnOp)
import Futhark.Analysis.PrimExp (PrimExp (..))
import Language.Futhark.Core (VName (..), nameFromString, nameFromText)
import Language.Futhark.Primitive
( BinOp (Add, FAdd, FMul, LogAnd, LogOr, Mul),
CmpOp,
ConvOp,
Overflow (OverflowWrap),
PrimType (Bool, FloatType, IntType),
PrimValue (BoolValue),
UnOp,
binOpType,
blankPrimValue,
cmpOpType,
convOpType,
doBinOp,
doCmpOp,
doConvOp,
doUnOp,
flipConvOp,
primFuns,
primValueType,
unOpType,
)
-- | A counter kept as the state of 'ADMonad'. 'incCounter' bumps it by one,
-- and 'vjpHandleOp' reads it to label the tape nodes it creates.
newtype Counter = Counter Int
  deriving (Show, Eq, Ord, Num)
-- | The monad in which AD operations run: failures are reported as
-- 'String's, and a 'Counter' is threaded through as state.
type ADMonad = ExceptT String (State Counter)

-- | Increments the 'Counter' held in the 'ADMonad' state by one.
incCounter :: ADMonad ()
incCounter = lift (modify (+ 1))
-- | A primitive operation that can be applied to 'ADValue's.
data Op
  = -- | A binary operator.
    OpBin BinOp
  | -- | A comparison operator. 'doOp'' always evaluates comparisons at
    -- depth 0, so they are never differentiated.
    OpCmp CmpOp
  | -- | A unary operator.
    OpUn UnOp
  | -- | A builtin function, looked up by name in 'primFuns'.
    OpFn T.Text
  | -- | A conversion between primitive types.
    OpConv ConvOp
  deriving (Show)
-- | Checks whether an operation can be applied to operands of the given
-- types.
--
-- The number of operands must equal the operation's arity:
--
-- * two for binary and comparison operators;
-- * one for unary and conversion operators;
-- * the declared parameter count for builtin functions.
--
-- Each operand type must also equal the type the operation expects.
-- Checking the arity here lets 'doOp'' reject malformed operand lists with
-- a proper error. Without it, an empty list would pass the check (`all` on
-- an empty list is True) and a builtin call with too few arguments would
-- pass too (`zipWith` truncates).
--
-- An unknown builtin function name is reported as a mismatch rather than
-- crashing.
opTypeMatch :: Op -> [PrimType] -> Bool
opTypeMatch (OpBin op) p = p == replicate 2 (binOpType op)
opTypeMatch (OpCmp op) p = p == replicate 2 (cmpOpType op)
opTypeMatch (OpUn op) p = p == [unOpType op]
opTypeMatch (OpConv op) p = p == [fst (convOpType op)]
opTypeMatch (OpFn fn) p =
  maybe False (\(params, _, _) -> params == p) (M.lookup fn primFuns)
-- | The type of the value produced by an operation.
--
-- Comparisons always produce a 'Bool'. 'cmpOpType' describes the type of a
-- comparison's /operands/, which is how 'opTypeMatch' uses it; it is not the
-- comparison's result type. For builtin functions the result type is taken
-- from 'primFuns', and an unknown function name is an internal error.
opReturnType :: Op -> PrimType
opReturnType (OpBin op) = binOpType op
opReturnType (OpCmp _) = Bool
opReturnType (OpUn op) = unOpType op
opReturnType (OpConv op) = snd $ convOpType op
opReturnType (OpFn fn) = case M.lookup fn primFuns of
  Just (_, t, _) -> t
  Nothing -> error "opReturnType" -- assumed: fn is a known primitive function
-- | The binary operator acting as addition for the given type.
--
-- * Integers: wrapping addition.
-- * Floats: floating-point addition.
-- * Booleans: logical or.
--
-- Any other type is an internal error.
addFor :: PrimType -> BinOp
addFor ty = case ty of
  IntType it -> Add it OverflowWrap
  FloatType ft -> FAdd ft
  Bool -> LogOr
  _ -> error ("addFor: " <> show ty)
-- | The binary operator acting as multiplication for the given type.
--
-- * Integers: wrapping multiplication.
-- * Floats: floating-point multiplication.
-- * Booleans: logical and.
--
-- Any other type is an internal error.
mulFor :: PrimType -> BinOp
mulFor ty = case ty of
  IntType it -> Mul it OverflowWrap
  FloatType ft -> FMul ft
  Bool -> LogAnd
  _ -> error ("mulFor: " <> show ty)
-- | The nesting level of a derivative computation. Constants have depth 0
-- (see 'depth'). 'doOp'' attributes an operation to the largest depth among
-- its operands.
newtype Depth = Depth Int
  deriving (Show, Eq, Ord)
-- | A value manipulated during automatic differentiation.
data ADValue
  = -- | A value taking part in the derivative at the given depth.
    Variable Depth ADVariable
  | -- | A plain primitive value, held constant with respect to every
    -- derivative.
    Constant PrimValue
  deriving (Show)
-- | The derivative-specific payload of a 'Variable'.
data ADVariable
  = -- | Reverse mode: the variable carries a 'Tape' recording how it was
    -- computed.
    VJP VJPValue
  | -- | Forward mode: the variable carries a 'JVPValue' whose first field
    -- is its primal value.
    JVP JVPValue
  deriving (Show)
-- | The derivative depth of a value. Constants always have depth 0.
depth :: ADValue -> Depth
depth val = case val of
  Variable d _ -> d
  Constant _ -> Depth 0
-- | The primal of a value.
--
-- * Constants are returned unchanged.
-- * A JVP variable is unwrapped recursively through its primal component.
-- * A VJP variable yields the primal stored on its tape. That primal is
--   not unwrapped further.
primal :: ADValue -> ADValue
primal val = case val of
  Constant c -> Constant c
  Variable _ (JVP (JVPValue p _)) -> primal p
  Variable _ (VJP (VJPValue tape)) -> tapePrimal tape
-- | Like 'primal', but only unwraps variables belonging to depth @cur@.
--
-- * Variables of any other depth, and constants, are returned unchanged.
-- * A JVP variable at depth @cur@ is unwrapped recursively.
-- * A VJP variable at depth @cur@ yields the primal stored on its tape.
primalFor :: Depth -> ADValue -> ADValue
primalFor cur val = case val of
  Constant c -> Constant c
  Variable tag _
    | tag /= cur -> val
  Variable _ (VJP (VJPValue tape)) -> tapePrimal tape
  Variable _ (JVP (JVPValue p _)) -> primalFor cur p
-- | The underlying primitive value of an 'ADValue'.
primitive :: ADValue -> PrimValue
primitive val = case val of
  Variable _ var -> varPrimal var
  Constant p -> p
-- | The underlying primitive value of an 'ADVariable'.
--
-- For VJP this is taken from the tape's primal. For JVP it is taken from the
-- primal component.
varPrimal :: ADVariable -> PrimValue
varPrimal var = primitive $ case var of
  VJP (VJPValue tape) -> tapePrimal tape
  JVP (JVPValue p _) -> primal p
-- | Evaluates a 'PrimExp' whose leaves are looked up in @env@.
--
-- Every operator is applied through 'doOp'', so derivative information
-- carried by the leaf values propagates through the expression. A leaf
-- name missing from @env@ is reported with 'throwE'.
evalPrimExp :: M.Map VName ADValue -> PrimExp VName -> ADMonad ADValue
evalPrimExp env expr = case expr of
  LeafExp name _ ->
    case M.lookup name env of
      Nothing -> throwE $ "Unknown variable " <> show name
      Just val -> pure val
  ValueExp pval -> pure (Constant pval)
  BinOpExp bop x y -> apply (OpBin bop) [x, y]
  CmpOpExp cop x y -> apply (OpCmp cop) [x, y]
  UnOpExp uop x -> apply (OpUn uop) [x]
  ConvOpExp cvt x -> apply (OpConv cvt) [x]
  FunExp fn args _ -> apply (OpFn fn) args
  where
    -- Evaluate the operands left to right, then apply the operation.
    apply :: Op -> [PrimExp VName] -> ADMonad ADValue
    apply op operands = mapM (evalPrimExp env) operands >>= doOp' op
-- | Symbolic partial derivatives of an operation, one per operand, expressed
-- over the given operand expressions.
--
-- Returns 'Nothing' in these cases:
--
-- * the operation has no rule here (comparisons and conversions);
-- * the operand count does not fit the operator;
-- * 'pdBuiltin' has no rule for the function.
lookupPDs :: Op -> [PrimExp VName] -> Maybe [PrimExp VName]
lookupPDs op args = case (op, args) of
  (OpBin bop, [x, y]) ->
    let (dx, dy) = pdBinOp bop x y
     in Just [dx, dy]
  (OpUn uop, [x]) -> Just [pdUnOp uop x]
  (OpFn fn, _) -> pdBuiltin (nameFromText fn) args
  _ -> Nothing
-- | Applies an operation to AD values, starting from the counter @uid@.
--
-- Returns the result together with the updated counter. If type-checking or
-- evaluating the operation fails, returns the error message instead.
doOp :: Op -> [ADValue] -> Counter -> Either String (ADValue, Counter)
doOp op o uid =
  either Left (\v -> Right (v, uid')) result
  where
    (result, uid') = runState (runExceptT (doOp' op o)) uid
-- | Applies an operation to AD values inside 'ADMonad'.
--
-- The operands' primitive types are first checked with 'opTypeMatch'; a
-- mismatch is reported with 'throwE'. The operation is then attributed to
-- the largest depth among its operands. Comparisons are always attributed
-- to depth 0, so they are never differentiated.
--
-- * At depth 0 the operation is evaluated directly on the primitive values
--   and the 'Counter' is incremented.
--
-- * At any other depth @d@:
--
--     * The primal result is obtained by recursively applying the operation
--       to the operands' primals at @d@ ('primalFor').
--     * Operands of a lower depth are treated as constants for this
--       derivative.
--     * The result is wrapped as a VJP tape node or a JVP value, depending
--       on the kind of the first operand found at depth @d@.
doOp' :: Op -> [ADValue] -> ADMonad ADValue
doOp' op o
  | not $ opTypeMatch op (map primValueType pv) =
      throwE $ unwords ["invalid types for op", show op, "and operands", show o]
  | otherwise = do
      let dep = case op of
            OpCmp _ -> Depth 0
            _ -> maxDepth
      if dep == Depth 0
        then maybe (throwE "failed to evaluate const") pure constCase <* incCounter
        else nonconstCase dep
  where
    pv :: [PrimValue]
    pv = map primitive o

    -- The largest operand depth. An empty operand list is treated as depth
    -- 0 and handled by 'constCase', instead of crashing in 'maximum'.
    maxDepth :: Depth
    maxDepth = case map depth o of
      [] -> Depth 0
      ds -> maximum ds

    -- Operands below depth @d@ (and constants) are constants for the
    -- derivative at @d@ (Left). Operands at @d@ expose their payload (Right).
    divideDepths :: Depth -> ADValue -> Either ADValue ADVariable
    divideDepths _ v@(Constant {}) = Left v
    divideDepths d v@(Variable d' v') = if d' < d then Left v else Right v'

    -- All operands at one depth share the same AD mode, so a JVP payload
    -- here is an internal error.
    extractVJP :: Either ADValue ADVariable -> Either ADValue VJPValue
    extractVJP (Right (VJP v)) = Right v
    extractVJP (Left v) = Left v
    extractVJP _ = error "extractVJP"

    extractJVP :: Either ADValue ADVariable -> Either ADValue JVPValue
    extractJVP (Right (JVP v)) = Right v
    extractJVP (Left v) = Left v
    extractJVP _ = error "extractJVP"

    -- Direct evaluation on primitive values. An operand count that does not
    -- fit the operation yields Nothing, which the caller reports through
    -- 'throwE' ("failed to evaluate const") rather than crashing.
    constCase :: Maybe ADValue
    constCase =
      Constant <$> case (op, pv) of
        (OpBin op', [x, y]) -> doBinOp op' x y
        (OpCmp op', [x, y]) -> BoolValue <$> doCmpOp op' x y
        (OpUn op', [x]) -> doUnOp op' x
        (OpConv op', [x]) -> doConvOp op' x
        (OpFn fn, _) -> do
          (_, _, f) <- M.lookup fn primFuns
          f pv
        _ -> Nothing

    nonconstCase :: Depth -> ADMonad ADValue
    nonconstCase dep = do
      let oprev = map (primalFor dep) o
      vprev <- doOp' op oprev
      let o' = map (divideDepths dep) o
      case find isRight o' of
        Just (Right (VJP {})) ->
          Variable dep . VJP . VJPValue
            <$> vjpHandleOp op (map extractVJP o') vprev
        Just (Right (JVP {})) ->
          Variable dep . JVP . JVPValue vprev
            <$> jvpHandleOp op (map extractJVP o')
        -- dep is the largest operand depth, so some operand is at dep.
        _ ->
          error "find isRight"
-- | Computes the partial derivatives of @op@ with respect to each of its
-- operands, evaluated at the operands' current values.
--
-- Each operand is bound to a fresh name @x1@, @x2@, and so on. This lets the
-- symbolic derivative expressions returned by 'lookupPDs' be evaluated with
-- 'evalPrimExp'.
--
-- Two kinds of failure are reported through 'throwE' instead of crashing
-- the interpreter with 'error':
--
-- * there is no derivative rule for @op@;
-- * an error occurs while evaluating a derivative expression.
calculatePDs :: Op -> [ADValue] -> ADMonad [ADValue]
calculatePDs op args = do
  pdes <- case lookupPDs op (zipWith leaf names args) of
    Nothing -> throwE "lookupPDs failed"
    Just es -> pure es
  mapM evalPD pdes
  where
    -- One fresh variable name per operand.
    names = map (\i -> VName (nameFromString $ "x" ++ show i) i) [1 .. length args]
    -- Environment binding each fresh name to its operand.
    env = M.fromList $ zip names args
    -- A leaf standing for an operand, typed by the operand's primitive value.
    leaf v val = LeafExp v $ primValueType $ primitive val
    -- Evaluate one derivative expression, prefixing any error message.
    evalPD e = evalPrimExp env e `catchE` (throwE . ("evalPrimExp failed: " <>))
-- | The reverse-mode (VJP) payload of a variable: the 'Tape' recording how
-- its value was computed.
newtype VJPValue = VJPValue Tape
  deriving (Show)
-- | A computation tree, together with the value at every node.
data Tape
  = -- | A variable, labelled with a 'Counter' identifier and holding its
    -- value.
    TapeID Counter ADValue
  | -- | A value that is constant with respect to the derivative being
    -- taken.
    TapeConst ADValue
  | -- | An application of an operation. It holds the tapes of the operands,
    -- a 'Counter' label, and the result value.
    TapeOp Op [Tape] Counter ADValue
  deriving (Show)
-- | The value stored at the root of a tape.
tapePrimal :: Tape -> ADValue
tapePrimal tape = case tape of
  TapeID _ val -> val
  TapeConst val -> val
  TapeOp _ _ _ val -> val
-- | Builds the tape node that records an application of @op@ for
-- reverse-mode differentiation.
--
-- 'Right' operands contribute their own tape. 'Left' operands are recorded
-- as 'TapeConst'. The node is labelled with the current 'Counter' value,
-- which is read but not incremented, and stores @v@ as its result.
vjpHandleOp :: Op -> [Either ADValue VJPValue] -> ADValue -> ADMonad Tape
vjpHandleOp op p v = mkNode <$> lift get
  where
    mkNode i = TapeOp op (map (either TapeConst (\(VJPValue t) -> t)) p) i v
-- | Like 'M.unionWith', but the combining function is monadic.
--
-- Keys present in only one map keep their value unchanged. For keys present
-- in both maps, @f@ is applied to the value from @m1@ and then the value from
-- @m2@. These effects run in ascending key order.
unionWithM :: (Monad m, Ord k) => (a -> a -> m a) -> M.Map k a -> M.Map k a -> m (M.Map k a)
unionWithM f m1 m2 = do
  -- Traversing a Map visits its keys in ascending order, so the effects of
  -- f are sequenced by key. This avoids the partial lookups (fromJust) and
  -- the lazy left fold of a key-by-key merge.
  both <- sequenceA (M.intersectionWith f m1 m2)
  -- The three key sets are disjoint, so the order of the union is irrelevant.
  pure $ M.unions [both, M.difference m1 m2, M.difference m2 m1]
-- | Merge a whole collection of maps with 'unionWithM'.
--
-- The fold starts from the empty map and proceeds left to right, so the
-- accumulated map is always passed as the first argument to @combine@.
unionsWithM :: (Foldable f, Monad m, Ord k) => (a -> a -> m a) -> f (M.Map k a) -> m (M.Map k a)
unionsWithM combine = foldM step M.empty
  where
    step acc next = unionWithM combine acc next
-- | Run reverse-mode differentiation over a tape.
--
-- Seeds the root of @tp@ with adjoint @s@ and threads the counter, starting
-- at @uid@, through the computation. On success it returns the adjoint map
-- together with the final counter value. On failure it returns the error
-- message; the counter is then discarded.
deriveTape :: Tape -> ADValue -> Counter -> Either String (M.Map Counter ADValue, Counter)
deriveTape tp s uid =
  let (result, uid') = runState (runExceptT (deriveTape' tp s)) uid
   in fmap (,uid') result
-- | Reverse-mode worker behind 'deriveTape': propagates the adjoint @s@ of
-- the root of the tape down to its leaves.
--
-- The result maps counters to accumulated adjoints. A 'TapeID' leaf is
-- keyed by its own counter. An intermediate 'TapeOp' node with counter
-- @uid@ is keyed by @-uid - 1@, presumably so that the two key spaces cannot
-- collide (this assumes counters are non-negative; TODO confirm). Entries of
-- both kinds are left in the returned map.
--
-- Shared sub-tapes are expanded only once. Reference counts are computed up
-- front by @countReferences@, and a 'TapeOp' waits until every parent has
-- contributed to its adjoint before pushing that adjoint to its operands.
--
-- Internal invariant violations (a node missing from, or over-visited with
-- respect to, the reference counts) are reported through 'ADMonad' rather
-- than by crashing.
deriveTape' :: Tape -> ADValue -> ADMonad (M.Map Counter ADValue)
deriveTape' (TapeID i _) s = pure $ M.singleton i s
deriveTape' (TapeConst _) _ = pure M.empty
deriveTape' tp@(TapeOp op p uid _) s =
  fst <$> derive tp s M.empty (countReferences p $ M.singleton (opKey uid) 1)
  where
    -- Map key under which a 'TapeOp' node's adjoint and reference count live.
    opKey :: Counter -> Counter
    opKey u = -u - 1

    -- NOTE(review): 'add' and 'mul' use the return type of the *root*
    -- operation for every node of the tape. That looks wrong for tapes that
    -- mix primitive types (e.g. through conversions). Confirm, and if so,
    -- take the type from the node being processed instead.
    add :: ADValue -> ADValue -> ADMonad ADValue
    add x y = doOp' (OpBin $ addFor $ opReturnType op) [x, y]

    mul :: ADValue -> ADValue -> ADMonad ADValue
    mul x y = doOp' (OpBin $ mulFor $ opReturnType op) [x, y]

    -- Combine contribution @a@ with whatever adjoint is already stored at @i@.
    accum :: Counter -> ADValue -> M.Map Counter ADValue -> ADMonad ADValue
    accum i a m = maybe (pure a) (add a) (M.lookup i m)

    -- Add contribution @a@ into the adjoint stored at @i@.
    madd :: Counter -> ADValue -> M.Map Counter ADValue -> ADMonad (M.Map Counter ADValue)
    madd i a m = (\x -> M.insert i x m) <$> accum i a m

    derive ::
      Tape ->
      ADValue ->
      M.Map Counter ADValue ->
      M.Map Counter Int ->
      ADMonad (M.Map Counter ADValue, M.Map Counter Int)
    derive (TapeID i _) s' ss rs = (,rs) <$> madd i s' ss
    derive (TapeConst _) _ ss rs = pure (ss, rs)
    derive (TapeOp op' p' uid' _) s' ss rs = do
      let key = opKey uid'
      -- One more parent has contributed; how many are still outstanding?
      r <- case M.lookup key rs of
        Just n -> pure (n - 1)
        Nothing -> throwE "deriveTape': tape node missing from reference counts"
      total <- accum key s' ss
      let ss' = M.insert key total ss
          rs' = M.insert key r rs
      case compare r 0 of
        -- Other parents still have to contribute; postpone this node.
        GT -> pure (ss', rs')
        -- The adjoint is complete: push it to the operands.
        EQ -> do
          contribs <- case op' of
            -- A conversion's adjoint is converted back to the operand type.
            OpConv op'' -> (: []) <$> doOp' (OpConv $ flipConvOp op'') [total]
            _ -> calculatePDs op' (map tapePrimal p') >>= mapM (mul total)
          foldlM
            (\(accSs, accRs) (child, adj) -> derive child adj accSs accRs)
            (ss', rs')
            (zip p' contribs)
        LT -> throwE "TODO: This branch is unreachable unless `countReferences` undercounts"

    -- Count, for every 'TapeOp' reachable from the given operands, how many
    -- parent edges point at it. A node's own operands are visited only the
    -- first time the node is seen.
    countReferences :: [Tape] -> M.Map Counter Int -> M.Map Counter Int
    countReferences p' d' = foldl visit d' p'
      where
        visit d'' (TapeOp _ p'' uid'' _) =
          case M.lookup (opKey uid'') d'' of
            Just v -> M.insert (opKey uid'') (v + 1) d''
            Nothing -> countReferences p'' $ M.insert (opKey uid'') 1 d''
        visit d'' _ = d''
-- | A forward-mode (JVP) value: a primal value paired with its tangent.
data JVPValue
  = JVPValue
      ADValue -- ^ primal value
      ADValue -- ^ tangent
  deriving (Show)
-- | Forward mode: compute the tangent of applying @op@ to the operands @p@.
--
-- Constant operands ('Left') contribute a zero tangent of the operator's
-- return type.
--
-- For a conversion operator, the tangent of the first operand is itself
-- converted with @op@. Applying a conversion to no operands is reported as
-- an error rather than crashing on 'head'.
--
-- For every other operator, the result is the sum over operands of
-- (partial derivative × operand tangent), starting from a zero of the
-- operator's return type.
jvpHandleOp :: Op -> [Either ADValue JVPValue] -> ADMonad ADValue
jvpHandleOp op p =
  case (op, p) of
    (OpConv _, x : _) ->
      -- The tangent of a conversion is the converted tangent.
      doOp' op [tangent x]
    (OpConv _, []) ->
      throwE "jvpHandleOp: conversion operator applied to no operands"
    _ -> do
      pds <- calculatePDs op $ map primal' p
      vs <- zipWithM mul pds $ map tangent p
      foldM add (Constant $ blankPrimValue op_t) vs
  where
    op_t :: PrimType
    op_t = opReturnType op

    primal' :: Either ADValue JVPValue -> ADValue
    primal' (Left v) = v
    primal' (Right (JVPValue v _)) = v

    tangent :: Either a JVPValue -> ADValue
    tangent (Left _) = Constant $ blankPrimValue op_t
    tangent (Right (JVPValue _ d)) = d

    add :: ADValue -> ADValue -> ADMonad ADValue
    add x y = doOp' (OpBin $ addFor op_t) [x, y]

    mul :: ADValue -> ADValue -> ADMonad ADValue
    mul x y = doOp' (OpBin $ mulFor op_t) [x, y]