module Language.Futhark.Interpreter.AD
  ( Op (..),
    ADVariable (..),
    ADValue (..),
    Tape (..),
    VJPValue (..),
    JVPValue (..),
    Counter (..),
    Depth (..),
    doOp,
    addFor,
    tapePrimal,
    primitive,
    varPrimal,
    deriveTape,
    unionWithM,
    unionsWithM,
  )
where

import Control.Monad (foldM, zipWithM)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Except (ExceptT, catchE, runExceptT, throwE)
import Control.Monad.Trans.State (State, get, modify, runState)
import Data.Either (isRight)
import Data.Foldable (find, foldlM)
import Data.Functor ((<&>))
import Data.Map qualified as M
import Data.Maybe (fromJust, fromMaybe)
import Data.Text qualified as T
import Futhark.AD.Derivatives (pdBinOp, pdBuiltin, pdUnOp)
import Futhark.Analysis.PrimExp (PrimExp (..))
import Language.Futhark.Core (VName (..), nameFromString, nameFromText)
import Language.Futhark.Primitive
  ( BinOp (Add, FAdd, FMul, LogAnd, LogOr, Mul),
    CmpOp,
    ConvOp,
    Overflow (OverflowWrap),
    PrimType (Bool, FloatType, IntType),
    PrimValue (BoolValue),
    UnOp,
    binOpType,
    blankPrimValue,
    cmpOpType,
    convOpType,
    doBinOp,
    doCmpOp,
    doConvOp,
    doUnOp,
    flipConvOp,
    primFuns,
    primValueType,
    unOpType,
  )

-- | Used to uniquely identify values.
--
-- The 'Num' instance is obtained through newtype deriving, so a counter
-- can be advanced with ordinary arithmetic (e.g. @c + 1@).
newtype Counter = Counter Int
  deriving (Show, Eq, Ord, Num)

-- | The monad AD operations run in: failures are reported as 'String'
-- messages, and a 'Counter' is threaded through as state.
type ADMonad = ExceptT String (State Counter)

-- | Advances the 'Counter' held in the 'ADMonad' state by one.
incCounter :: ADMonad ()
incCounter = lift (modify (+ 1))

-- | Mathematical operations subject to AD.
data Op
  = -- | A binary operator.
    OpBin BinOp
  | -- | A comparison operator.
    OpCmp CmpOp
  | -- | A unary operator.
    OpUn UnOp
  | -- | A built-in function, identified by its key in 'primFuns'.
    OpFn T.Text
  | -- | A conversion operator.
    OpConv ConvOp
  deriving (Show)

-- | Checks whether an operation can be applied to operands of the given
-- types.
--
-- Both the number of operands and the type of every operand must match.
-- The arity is checked here so that a call with too few or too many
-- operands is rejected by this check. Otherwise it would slip through
-- and reach code that assumes the arity is right.
--
-- For 'OpFn', the function is assumed to exist in 'primFuns'; an unknown
-- name raises 'error'.
opTypeMatch :: Op -> [PrimType] -> Bool
opTypeMatch op p = case op of
  OpBin o -> uniform 2 (binOpType o)
  OpCmp o -> uniform 2 (cmpOpType o)
  OpUn o -> uniform 1 (unOpType o)
  -- A conversion is checked against its /input/ type.
  OpConv o -> uniform 1 (fst (convOpType o))
  OpFn fn -> case M.lookup fn primFuns of
    -- Comparing the whole lists (rather than zipping them) also rejects
    -- argument lists that are shorter or longer than the parameter list.
    Just (t, _, _) -> t == p
    Nothing -> error "opTypeMatch" -- It is assumed that the function exists
  where
    -- Exactly @n@ operands, all of type @t@.
    uniform n t = length p == n && all (== t) p

-- | Gets the type of the value produced by an operation.
--
-- A comparison always yields a 'Bool', whatever the type of its operands.
-- 'cmpOpType' describes the operand type, not the result type, so it
-- must not be used here.
opReturnType :: Op -> PrimType
opReturnType (OpBin op) = binOpType op
opReturnType (OpCmp _) = Bool
opReturnType (OpUn op) = unOpType op
-- A conversion produces its /output/ type.
opReturnType (OpConv op) = snd $ convOpType op
opReturnType (OpFn fn) = case M.lookup fn primFuns of
  Just (_, t, _) -> t
  Nothing -> error "opReturnType" -- It is assumed that the function exists

-- | The binary operation that performs addition (or its analogue) on the
-- given type:
--
-- * wrapping integer addition for integers;
-- * floating-point addition for floats;
-- * logical disjunction for booleans.
--
-- Any other type is an error.
addFor :: PrimType -> BinOp
addFor pt = case pt of
  IntType it -> Add it OverflowWrap
  FloatType ft -> FAdd ft
  Bool -> LogOr
  _ -> error ("addFor: " <> show pt)

-- | The binary operation that performs multiplication (or its analogue)
-- on the given type:
--
-- * wrapping integer multiplication for integers;
-- * floating-point multiplication for floats;
-- * logical conjunction for booleans.
--
-- Any other type is an error.
mulFor :: PrimType -> BinOp
mulFor pt = case pt of
  IntType it -> Mul it OverflowWrap
  FloatType ft -> FMul ft
  Bool -> LogAnd
  _ -> error ("mulFor: " <> show pt)

-- | An indication of the nesting depth of AD. This is used to avoid
-- perturbation confusion: every AD variable is tagged with the depth it
-- belongs to, so nested differentiations can tell their variables apart.
newtype Depth = Depth Int
  deriving (Show, Eq, Ord)

-- Types and utility functions --

-- | A value that may take part in AD. When taking the partial derivative
-- of a function, we must distinguish between the values that are held
-- constant and those that are not.
data ADValue
  = -- | A value being differentiated, tagged with the AD nesting depth it
    -- belongs to.
    Variable Depth ADVariable
  | -- | A value that is held constant.
    Constant PrimValue
  deriving (Show)

-- | When performing automatic differentiation, each derived variable must
-- be augmented with additional data. This type holds that data for
-- either vector-Jacobian products ('VJP') or Jacobian-vector products
-- ('JVP'). Either form also gives access to the primal value of the
-- variable (see 'varPrimal').
data ADVariable
  = VJP VJPValue
  | JVP JVPValue
  deriving (Show)

-- | The AD nesting depth of a value; constants are at depth 0.
depth :: ADValue -> Depth
depth val = case val of
  Variable d _ -> d
  Constant _ -> Depth 0

-- | Strips derivative information from a value.
--
-- * For a VJP variable, returns the primal recorded on its tape, without
--   further processing.
-- * For a JVP variable, recurses on the wrapped primal.
-- * Constants are returned unchanged.
primal :: ADValue -> ADValue
primal val = case val of
  Variable _ (VJP (VJPValue tape)) -> tapePrimal tape
  Variable _ (JVP (JVPValue inner _)) -> primal inner
  Constant _ -> val

-- | Strips the derivative information that belongs to depth @cur@.
-- Variables tagged with any other depth, and constants, are returned
-- unchanged.
primalFor :: Depth -> ADValue -> ADValue
primalFor cur val = case val of
  -- A variable from another depth is left alone; otherwise fall through
  -- to the alternatives below.
  Variable tag _
    | tag /= cur -> val
  Variable _ (VJP (VJPValue tape)) -> tapePrimal tape
  Variable _ (JVP (JVPValue inner _)) -> primalFor cur inner
  Constant _ -> val

-- | The underlying primitive value of an 'ADValue', with all derivative
-- information removed.
primitive :: ADValue -> PrimValue
primitive val = case val of
  Constant pv -> pv
  Variable _ var -> varPrimal var

-- | The underlying primitive value of an 'ADVariable'.
varPrimal :: ADVariable -> PrimValue
varPrimal var = primitive $ case var of
  VJP (VJPValue tape) -> tapePrimal tape
  JVP (JVPValue p _) -> primal p

-- | Evaluates a 'PrimExp' over 'ADValue's.
--
-- Leaves are looked up in the given map; an unbound name is reported via
-- 'throwE'. Every operation is performed with doOp', so derivative
-- information propagates whenever an operand is a 'Variable'. Operands
-- are evaluated left to right.
evalPrimExp :: M.Map VName ADValue -> PrimExp VName -> ADMonad ADValue
evalPrimExp env expr = case expr of
  LeafExp name _ -> case M.lookup name env of
    Nothing -> throwE $ "Unknown variable " <> show name
    Just v -> pure v
  ValueExp pv -> pure (Constant pv)
  BinOpExp op x y -> binary (OpBin op) x y
  CmpOpExp op x y -> binary (OpCmp op) x y
  UnOpExp op x -> unary (OpUn op) x
  ConvOpExp op x -> unary (OpConv op) x
  FunExp fn args _ -> mapM recur args >>= doOp' (OpFn fn)
  where
    recur :: PrimExp VName -> ADMonad ADValue
    recur = evalPrimExp env

    unary :: Op -> PrimExp VName -> ADMonad ADValue
    unary o x = do
      a <- recur x
      doOp' o [a]

    binary :: Op -> PrimExp VName -> PrimExp VName -> ADMonad ADValue
    binary o x y = do
      a <- recur x
      b <- recur y
      doOp' o [a, b]

-- | Returns one 'PrimExp' per operand, each computing the partial
-- derivative of the operation with respect to that operand.
--
-- Binary and unary operators are handled here directly. Built-in
-- functions defer to 'pdBuiltin'. Every other case yields 'Nothing',
-- including comparisons, conversions, and operand counts that do not fit
-- the operator.
lookupPDs :: Op -> [PrimExp VName] -> Maybe [PrimExp VName]
lookupPDs op args = case (op, args) of
  (OpBin bop, [x, y]) ->
    let (dx, dy) = pdBinOp bop x y
     in Just [dx, dy]
  (OpUn uop, [x]) -> Just [pdUnOp uop x]
  (OpFn fn, _) -> pdBuiltin (nameFromText fn) args
  _ -> Nothing

-- Shared AD logic --

-- | Performs a mathematical operation on a list of operands, performing
-- automatic differentiation if one or more operands is a 'Variable' (of
-- depth > 0).
--
-- The supplied 'Counter' is threaded through the computation, and its
-- final value is returned alongside the result. On failure only the
-- error message is returned.
doOp :: Op -> [ADValue] -> Counter -> Either String (ADValue, Counter)
doOp op o uid =
  case runState (runExceptT (doOp' op o)) uid of
    (result, uid') -> fmap (\v -> (v, uid')) result

doOp' :: Op -> [ADValue] -> ADMonad ADValue
doOp' :: Op -> [ADValue] -> ADMonad ADValue
doOp' Op
op [ADValue]
o
  | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Op -> [PrimType] -> Bool
opTypeMatch Op
op ((PrimValue -> PrimType) -> [PrimValue] -> [PrimType]
forall a b. (a -> b) -> [a] -> [b]
map PrimValue -> PrimType
primValueType [PrimValue]
pv) =
      -- This function may be called with arguments of invalid types,
      -- because it is used as part of an overloaded operator.
      String -> ADMonad ADValue
forall (m :: * -> *) e a. Monad m => e -> ExceptT e m a
throwE (String -> ADMonad ADValue) -> String -> ADMonad ADValue
forall a b. (a -> b) -> a -> b
$ [String] -> String
unwords [String
"invalid types for op", Op -> String
forall a. Show a => a -> String
show Op
op, String
"and operands", [ADValue] -> String
forall a. Show a => a -> String
show [ADValue]
o]
  | Bool
otherwise = do
      let dep :: Depth
dep = case Op
op of
            OpCmp CmpOp
_ -> Int -> Depth
Depth Int
0 -- AD is not well-defined for comparason operations
            -- There are no derivatives for those written in
            -- PrimExp (check lookupPDs)
            Op
_ -> [Depth] -> Depth
forall a. Ord a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Ord a) => t a -> a
maximum ((ADValue -> Depth) -> [ADValue] -> [Depth]
forall a b. (a -> b) -> [a] -> [b]
map ADValue -> Depth
depth [ADValue]
o)
      if Depth
dep Depth -> Depth -> Bool
forall a. Eq a => a -> a -> Bool
== Int -> Depth
Depth Int
0
        then ADMonad ADValue
-> (ADValue -> ADMonad ADValue) -> Maybe ADValue -> ADMonad ADValue
forall b a. b -> (a -> b) -> Maybe a -> b
maybe (String -> ADMonad ADValue
forall (m :: * -> *) e a. Monad m => e -> ExceptT e m a
throwE String
"failed to evaluate const") ADValue -> ADMonad ADValue
forall a. a -> ExceptT String (State Counter) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe ADValue
constCase ADMonad ADValue -> ADMonad () -> ADMonad ADValue
forall a b.
ExceptT String (State Counter) a
-> ExceptT String (State Counter) b
-> ExceptT String (State Counter) a
forall (f :: * -> *) a b. Applicative f => f a -> f b -> f a
<* ADMonad ()
incCounter
        else Depth -> ADMonad ADValue
nonconstCase Depth
dep
  where
    pv :: [PrimValue]
pv = (ADValue -> PrimValue) -> [ADValue] -> [PrimValue]
forall a b. (a -> b) -> [a] -> [b]
map ADValue -> PrimValue
primitive [ADValue]
o

    divideDepths :: Depth -> ADValue -> Either ADValue ADVariable
    divideDepths :: Depth -> ADValue -> Either ADValue ADVariable
divideDepths Depth
_ v :: ADValue
v@(Constant {}) = ADValue -> Either ADValue ADVariable
forall a b. a -> Either a b
Left ADValue
v
    divideDepths Depth
d v :: ADValue
v@(Variable Depth
d' ADVariable
v') = if Depth
d' Depth -> Depth -> Bool
forall a. Ord a => a -> a -> Bool
< Depth
d then ADValue -> Either ADValue ADVariable
forall a b. a -> Either a b
Left ADValue
v else ADVariable -> Either ADValue ADVariable
forall a b. b -> Either a b
Right ADVariable
v'

    -- TODO: There may be a more graceful way of
    -- doing this
    extractVJP :: Either ADValue ADVariable -> Either ADValue VJPValue
    extractVJP :: Either ADValue ADVariable -> Either ADValue VJPValue
extractVJP (Right (VJP VJPValue
v)) = VJPValue -> Either ADValue VJPValue
forall a b. b -> Either a b
Right VJPValue
v
    extractVJP (Left ADValue
v) = ADValue -> Either ADValue VJPValue
forall a b. a -> Either a b
Left ADValue
v
    extractVJP Either ADValue ADVariable
_ =
      -- This will never be called when the maximum depth layer is JVP
      String -> Either ADValue VJPValue
forall a. HasCallStack => String -> a
error String
"extractVJP"

    -- TODO: There may be a more graceful way of
    -- doing this
    extractJVP :: Either ADValue ADVariable -> Either ADValue JVPValue
    extractJVP :: Either ADValue ADVariable -> Either ADValue JVPValue
extractJVP (Right (JVP JVPValue
v)) = JVPValue -> Either ADValue JVPValue
forall a b. b -> Either a b
Right JVPValue
v
    extractJVP (Left ADValue
v) = ADValue -> Either ADValue JVPValue
forall a b. a -> Either a b
Left ADValue
v
    extractJVP Either ADValue ADVariable
_ =
      -- This will never be called when the maximum depth layer is VJP
      String -> Either ADValue JVPValue
forall a. HasCallStack => String -> a
error String
"extractJVP"

    -- Every operand is a constant, so the operation is simply evaluated
    -- on the primitive values and the result wrapped as a 'Constant'.
    -- 'Nothing' means the primitive evaluation itself failed.
    constCase :: Maybe ADValue
    constCase = fmap Constant evaluated
      where
        evaluated = case (op, pv) of
          (OpBin bop, [a, b]) -> doBinOp bop a b
          (OpCmp cop, [a, b]) -> BoolValue <$> doCmpOp cop a b
          (OpUn uop, [a]) -> doUnOp uop a
          (OpConv cvop, [a]) -> doConvOp cvop a
          (OpFn fname, _) ->
            M.lookup fname primFuns >>= \(_, _, fun) -> fun pv
          _ -> error "doOp': opTypeMatch"

    -- Some operands are variables, so AD bookkeeping is needed at depth
    -- @dep@. First the primal result is computed one depth below. The
    -- operation is then recorded according to the AD mode (VJP or JVP)
    -- of the variables at depth @dep@.
    nonconstCase :: Depth -> ADMonad ADValue
    nonconstCase dep = do
      vprev <- doOp' op (map (primalFor dep) o)
      -- Operands living at depth @dep@ become 'Right'; the rest 'Left'.
      let layered = map (divideDepths dep) o
      case find isRight layered of
        Just (Right (VJP {})) ->
          Variable dep . VJP . VJPValue
            <$> vjpHandleOp op (map extractVJP layered) vprev
        Just (Right (JVP {})) ->
          Variable dep . JVP . JVPValue vprev
            <$> jvpHandleOp op (map extractJVP layered)
        -- The maximum depth is non-zero, so at least one operand must
        -- be a variable of depth > 0.
        _ -> error "find isRight"

-- | Evaluates the partial derivative of @op@ with respect to each of its
-- operands, at the point given by the operands' values. A missing
-- derivative rule or a failed evaluation aborts with 'error' rather
-- than being reported through 'ADMonad'.
calculatePDs :: Op -> [ADValue] -> ADMonad [ADValue]
calculatePDs op args = mapM evalPD partials
  where
    -- One fresh name per operand: x1, x2, ...
    names = [VName (nameFromString ("x" ++ show i)) i | i <- [1 .. length args]]
    -- Environment binding each name to its operand.
    env = M.fromList (zip names args)
    -- Typed leaf expressions standing for the operands.
    leaves = zipWith (\nm arg -> LeafExp nm (primValueType (primitive arg))) names args
    -- Symbolic partial derivatives, one per operand.
    partials = fromMaybe (error "lookupPDs failed") (lookupPDs op leaves)
    evalPD e =
      evalPrimExp env e `catchE` \msg -> error ("evalPrimExp failed: " <> msg)

-- VJP / Reverse mode automatic differentiation--
-- Reverse mode needs the whole computation that led up to a value, so
-- that it can later be walked backwards. That history is a 'Tape'.

-- | A reverse-mode (VJP) value, represented by the 'Tape' that produced it.
newtype VJPValue = VJPValue Tape
  deriving Show

-- | A recorded computation tree for reverse-mode AD, which also keeps
-- the value computed at every intermediate node.
data Tape
  = -- | A variable: its unique ID and its initial value.
    TapeID Counter ADValue
  | -- | A constant.
    TapeConst ADValue
  | -- | An applied operation: the operation, the tape of each operand,
    -- the node's ID, and the value the operation produced.
    TapeOp Op [Tape] Counter ADValue
  deriving Show

-- | Extracts the value stored at the root of a 'Tape'.
tapePrimal :: Tape -> ADValue
tapePrimal tape =
  case tape of
    TapeID _ v -> v
    TapeConst v -> v
    TapeOp _ _ _ v -> v

-- | Records the application of @op@ on the tape, producing the new
-- 'TapeOp' node with result value @v@. Operands of a lower depth
-- ('Left') are treated as constants.
--
-- The node is tagged with the current counter value, and the counter is
-- then advanced. 'deriveTape'' uses this ID to key reference counts and
-- sensitivities. Two operations sharing an ID would be conflated there,
-- e.g. a child op would bump its parent's reference count and stop the
-- backward pass early.
vjpHandleOp :: Op -> [Either ADValue VJPValue] -> ADValue -> ADMonad Tape
vjpHandleOp op p v = do
  i <- lift get
  lift $ modify (\(Counter c) -> Counter (c + 1))
  pure $ TapeOp op (map toTape p) i v
  where
    toTape (Left v') = TapeConst v'
    toTape (Right (VJPValue t)) = t

-- | Like 'M.unionWith', but the combining function runs in a monad.
-- Keys present in only one map keep their value. For keys present in
-- both, @f@ is applied to the value from the first map and the value
-- from the second, with the effects sequenced in ascending key order.
unionWithM :: (Monad m, Ord k) => (a -> a -> m a) -> M.Map k a -> M.Map k a -> m (M.Map k a)
unionWithM f m1 m2 = do
  -- 'M.intersectionWith' pairs up shared keys without partial lookups,
  -- and traversing a 'M.Map' runs the effects in key order.
  shared <- sequence $ M.intersectionWith f m1 m2
  -- The three key sets are disjoint, so the union order is irrelevant.
  pure $ M.unions [shared, M.difference m1 m2, M.difference m2 m1]

-- | Merges a collection of maps left to right with 'unionWithM',
-- starting from the empty map.
unionsWithM :: (Foldable f, Monad m, Ord k) => (a -> a -> m a) -> f (M.Map k a) -> m (M.Map k a)
unionsWithM combine maps =
  foldlM (\acc next -> unionWithM combine acc next) M.empty maps

-- | Computes every partial derivative recorded in a 'Tape', given the
-- sensitivity of its result and the current ID counter.
--
-- On success, returns the sensitivity map together with the updated
-- counter. The map is keyed by the ID of each free variable (see
-- 'TapeID'). It also holds entries for operation nodes, keyed as
-- @-uid - 1@. Failures reported by the computation come back as 'Left'.
deriveTape :: Tape -> ADValue -> Counter -> Either String (M.Map Counter ADValue, Counter)
deriveTape tp s uid =
  let (outcome, uid') = runState (runExceptT (deriveTape' tp s)) uid
   in fmap (\sens -> (sens, uid')) outcome

-- | Worker for 'deriveTape'. It propagates the sensitivity @s@ of the
-- root of a tape back towards its 'TapeID' leaves and accumulates the
-- sensitivity per key.
--
-- Operation nodes are keyed as @-uid - 1@ in both the sensitivity map
-- and the reference-count map. A node is only expanded once all
-- references to it, as counted by @countReferences@, have delivered
-- their sensitivity.
deriveTape' :: Tape -> ADValue -> ADMonad (M.Map Counter ADValue)
deriveTape' (TapeID i _) s = pure $ M.singleton i s
deriveTape' (TapeConst _) _ = pure M.empty
deriveTape' tp@(TapeOp _ p uid _) s =
  fst <$> derive tp s M.empty (countReferences p $ M.singleton (opKey uid) 1)
  where
    -- Map key of an operation node with the given ID.
    opKey :: Counter -> Counter
    opKey u = -u - 1

    -- Arithmetic on sensitivities is done at the type of the values
    -- involved, not at the root operation's return type. A tape may
    -- contain nodes of other types, e.g. below an 'OpConv'. Using the
    -- root's type there would, for instance, apply an f64 multiplication
    -- to f32 values.
    valueType :: ADValue -> PrimType
    valueType = primValueType . primitive

    add :: ADValue -> ADValue -> ADMonad ADValue
    add x y = doOp' (OpBin $ addFor $ valueType x) [x, y]

    mul :: ADValue -> ADValue -> ADMonad ADValue
    mul x y = doOp' (OpBin $ mulFor $ valueType x) [x, y]

    -- Adds sensitivity @a@ to whatever is already accumulated under @i@.
    madd :: Counter -> ADValue -> M.Map Counter ADValue -> ADMonad (M.Map Counter ADValue)
    madd i a m = case M.lookup i m of
      Just b -> add a b <&> \x -> M.insert i x m
      Nothing -> pure $ M.insert i a m

    derive ::
      Tape ->
      ADValue ->
      M.Map Counter ADValue ->
      M.Map Counter Int ->
      ADMonad (M.Map Counter ADValue, M.Map Counter Int)
    derive (TapeID i _) s' ss rs = madd i s' ss <&> (,rs)
    derive (TapeConst _) _ ss rs = pure (ss, rs)
    derive (TapeOp op' p' uid' _) s' ss rs = do
      let key = opKey uid'
          -- One fewer reference is still pending for this node.
          r = fromJust (M.lookup key rs) - 1
          rs' = M.insert key r rs
      -- Add the incoming sensitivity.
      ss' <- madd key s' ss
      case compare r 0 of
        -- Other references have not delivered yet; wait for them.
        GT -> pure (ss', rs')
        -- All references are in, so expand this node.
        EQ -> do
          let s'' = fromJust (M.lookup key ss')
          -- Calculate the sensitivities of the operands.
          s''' <- case op' of
            OpConv op'' ->
              -- For a type conversion, simply convert the sensitivity back.
              sequence [doOp' (OpConv $ flipConvOp op'') [s'']]
            _ -> calculatePDs op' (map tapePrimal p') >>= mapM (mul s'')
          -- Propagate them to the operands.
          foldlM
            (\(ss'', rs'') (p'', s'''') -> derive p'' s'''' ss'' rs'')
            (ss', rs')
            (zip p' s''')
        LT -> error "TODO: This branch is unreachable unless `countReferences` undercounts"

    -- Counts, for every operation node reachable from the given tapes,
    -- how many times it is referenced. Each node's children are only
    -- visited the first time the node is seen.
    countReferences :: [Tape] -> M.Map Counter Int -> M.Map Counter Int
    countReferences p' d' = foldl f d' p'

    f :: M.Map Counter Int -> Tape -> M.Map Counter Int
    f d'' x =
      case x of
        TapeOp _ p'' uid'' _ -> case M.lookup (opKey uid'') d'' of
          Just v -> M.insert (opKey uid'') (v + 1) d''
          Nothing -> countReferences p'' $ M.insert (opKey uid'') 1 d''
        _ -> d''

-- JVP / Forward mode automatic differentiation--

-- | A forward-mode (JVP) value. Alongside the value itself, the
-- derivative of the variable is carried as a second component.
data JVPValue
  = JVPValue
      ADValue
      -- ^ The primal value.
      ADValue
      -- ^ The tangent (derivative).
  deriving Show

-- | Computes the tangent of the result of applying @op@ to operands that
-- are either constants ('Left') or forward-mode values ('Right').
-- Constants contribute a zero tangent.
jvpHandleOp :: Op -> [Either ADValue JVPValue] -> ADMonad ADValue
jvpHandleOp op p =
  case op of
    OpConv _ ->
      -- A type conversion only needs to convert the incoming tangent.
      doOp' op [tangentOf (head p)]
    _ -> do
      -- Chain rule: sum over i of (d op / d x_i) * tangent(x_i).
      partials <- calculatePDs op (map primalOf p)
      terms <- zipWithM mulT partials (map tangentOf p)
      foldM addT zero terms
  where
    retType = opReturnType op
    zero = Constant (blankPrimValue retType)
    primalOf = either id (\(JVPValue v _) -> v)
    tangentOf = either (const zero) (\(JVPValue _ d) -> d)
    addT x y = doOp' (OpBin (addFor retType)) [x, y]
    mulT x y = doOp' (OpBin (mulFor retType)) [x, y]