module Language.Futhark.Interpreter.AD
  ( Op (..),
    ADVariable (..),
    ADValue (..),
    Tape (..),
    VJPValue (..),
    JVPValue (..),
    doOp,
    addFor,
    tapePrimal,
    primitive,
    varPrimal,
    deriveTape,
  )
where

import Control.Monad (foldM, zipWithM)
import Data.Either (isRight)
import Data.List (find, foldl')
import Data.Map qualified as M
import Data.Maybe (fromMaybe)
import Data.Text qualified as T
import Futhark.AD.Derivatives (pdBinOp, pdBuiltin, pdUnOp)
import Futhark.Analysis.PrimExp (PrimExp (..))
import Language.Futhark.Core (VName (..), nameFromString, nameFromText)
import Language.Futhark.Primitive

-- | Mathematical operations subject to AD.
data Op
  = -- | A binary operator.
    OpBin BinOp
  | -- | A comparison operator.
    OpCmp CmpOp
  | -- | A unary operator.
    OpUn UnOp
  | -- | A named builtin function; the name is looked up in 'primFuns'.
    OpFn T.Text
  | -- | A type conversion.
    OpConv ConvOp
  deriving (Show)

-- | Checks whether the operand types @p@ are acceptable for an
-- operation.
--
-- For operators, every operand type must equal the operator's type
-- (for conversions: the first component of 'convOpType', i.e. the type
-- being converted from).  The number of operands is not checked for
-- operators.
--
-- For named functions, the operand types must equal the parameter
-- types registered in 'primFuns' exactly.  Comparing the lists
-- directly (rather than zipping them) also rejects calls with too few
-- or too many operands, which a truncating @zipWith@ would let through.
--
-- Calls 'error' if the function name is not present in 'primFuns'.
opTypeMatch :: Op -> [PrimType] -> Bool
opTypeMatch (OpBin op) p = all (== binOpType op) p
opTypeMatch (OpCmp op) p = all (== cmpOpType op) p
opTypeMatch (OpUn op) p = all (== unOpType op) p
opTypeMatch (OpConv op) p = all (== fst (convOpType op)) p
opTypeMatch (OpFn fn) p = case M.lookup fn primFuns of
  Just (t, _, _) -> t == p
  Nothing -> error "opTypeMatch" -- It is assumed that the function exists

-- | Gets the return type of an operation.
--
-- For conversions this is the second component of 'convOpType'; for
-- named functions it is the result type registered in 'primFuns'.
-- Calls 'error' if the function name is not present in 'primFuns'.
opReturnType :: Op -> PrimType
opReturnType (OpBin op) = binOpType op
-- NOTE(review): this yields whatever 'cmpOpType' reports, which
-- 'opTypeMatch' treats as the operand type; a comparison presumably
-- produces a boolean instead - confirm before relying on this case.
opReturnType (OpCmp op) = cmpOpType op
opReturnType (OpUn op) = unOpType op
opReturnType (OpConv op) = snd (convOpType op)
opReturnType (OpFn fn) = case M.lookup fn primFuns of
  Just (_, t, _) -> t
  Nothing -> error "opReturnType" -- It is assumed that the function exists

-- | Returns the operation which performs addition (or an equivalent
-- operation) on the given type: wrapping 'Add' for integers, 'FAdd'
-- for floats and 'LogOr' for booleans.  Calls 'error' for any other
-- type.
addFor :: PrimType -> BinOp
addFor (IntType t) = Add t OverflowWrap
addFor (FloatType t) = FAdd t
addFor Bool = LogOr
addFor t = error $ "addFor: " ++ show t

-- | Returns the operation which performs multiplication (or an
-- equivalent operation) on the given type: wrapping 'Mul' for
-- integers, 'FMul' for floats and 'LogAnd' for booleans.  Calls
-- 'error' for any other type.
mulFor :: PrimType -> BinOp
mulFor (IntType t) = Mul t OverflowWrap
mulFor (FloatType t) = FMul t
mulFor Bool = LogAnd
mulFor t = error $ "mulFor: " ++ show t

-- Types and utility functions--

-- | When taking the partial derivative of a function, we must
-- differentiate between the values which are kept constant, and those
-- which are not.
data ADValue
  = -- | A value being differentiated, tagged with its depth (see
    -- 'depth') and carrying the mode-specific AD data.
    Variable Int ADVariable
  | -- | A value that is kept constant.
    Constant PrimValue
  deriving (Show)

-- | When performing automatic differentiation, each derived variable
-- must be augmented with additional data.  This value holds the
-- primitive value of the variable, as well as its data.
data ADVariable
  = -- | Reverse-mode data (a 'Tape').
    VJP VJPValue
  | -- | Forward-mode data (a primal and a derivative).
    JVP JVPValue
  deriving (Show)

-- | The depth tag of a value: the stored depth of a 'Variable', and 0
-- for a 'Constant'.
depth :: ADValue -> Int
depth (Variable d _) = d
depth (Constant _) = 0

-- | Strips the AD data from a value, yielding its primal.
--
-- For a VJP variable this is the primal recorded on its tape.  For a
-- JVP variable the stored primal is itself passed through 'primal'
-- again, so unwrapping continues for as long as that stored primal is
-- a JVP variable.  Constants are returned unchanged.
primal :: ADValue -> ADValue
primal (Variable _ (VJP (VJPValue t))) = tapePrimal t
primal (Variable _ (JVP (JVPValue v _))) = primal v
primal (Constant v) = Constant v

-- | The underlying primitive value of an 'ADValue', with all AD data
-- removed (variables are resolved through 'varPrimal').
primitive :: ADValue -> PrimValue
primitive (Variable _ v) = varPrimal v
primitive (Constant v) = v

-- | The underlying primitive value of an 'ADVariable': the primitive
-- of the tape's primal for VJP, and the primitive of the stored
-- primal for JVP.
varPrimal :: ADVariable -> PrimValue
varPrimal (VJP (VJPValue t)) = primitive (tapePrimal t)
varPrimal (JVP (JVPValue v _)) = primitive (primal v)

-- | Evaluates a 'PrimExp' using 'doOp', so that AD is carried through
-- the evaluation.
--
-- Leaves are looked up in the environment @m@; the result is 'Nothing'
-- if a leaf is unbound or if any 'doOp' application fails.
evalPrimExp :: M.Map VName ADValue -> PrimExp VName -> Maybe ADValue
evalPrimExp m (LeafExp n _) = M.lookup n m
evalPrimExp _ (ValueExp pv) = Just (Constant pv)
evalPrimExp m (BinOpExp op x y) = do
  x' <- evalPrimExp m x
  y' <- evalPrimExp m y
  doOp (OpBin op) [x', y']
evalPrimExp m (CmpOpExp op x y) = do
  x' <- evalPrimExp m x
  y' <- evalPrimExp m y
  doOp (OpCmp op) [x', y']
evalPrimExp m (UnOpExp op x) = do
  x' <- evalPrimExp m x
  doOp (OpUn op) [x']
evalPrimExp m (ConvOpExp op x) = do
  x' <- evalPrimExp m x
  doOp (OpConv op) [x']
evalPrimExp m (FunExp fn p _) = do
  p' <- mapM (evalPrimExp m) p
  doOp (OpFn fn) p'

-- | Returns a list of 'PrimExp's calculating the partial derivative
-- with respect to each operand of a given operation.
--
-- Binary and unary operators are handled by 'pdBinOp' and 'pdUnOp';
-- named functions are delegated to 'pdBuiltin'.  Any other
-- combination of operation and operand count (including all
-- comparisons and conversions) yields 'Nothing'.
lookupPDs :: Op -> [PrimExp VName] -> Maybe [PrimExp VName]
lookupPDs (OpBin op) [x, y] =
  let (a, b) = pdBinOp op x y
   in Just [a, b]
lookupPDs (OpUn op) [x] = Just [pdUnOp op x]
lookupPDs (OpFn fn) p = pdBuiltin (nameFromText fn) p
lookupPDs _ _ = Nothing

-- Shared AD logic--

-- | Performs a mathematical operation on a list of operands,
-- performing automatic differentiation if one or more operands is a
-- 'Variable' (of depth > 0).
--
-- Returns 'Nothing' when the operand types do not match the operation
-- (see 'opTypeMatch') or when the underlying primitive operation
-- fails.  Calls 'error' when the operand count does not fit the
-- operation.
doOp :: Op -> [ADValue] -> Maybe ADValue
doOp op o
  | not $ opTypeMatch op (map primValueType pv) =
      -- This function may be called with arguments of invalid types,
      -- because it is used as part of an overloaded operator.
      Nothing
  | otherwise = do
      let dep = case op of
            -- AD is not well-defined for comparison operations.
            -- There are no derivatives for those written in
            -- PrimExp (check lookupPDs)
            OpCmp _ -> 0
            -- Seeding with 0 (the depth of a constant) keeps this
            -- total when there are no operands.
            _ -> maximum (0 : map depth o)
      if dep == 0 then constCase else nonconstCase dep
  where
    -- The operands with all AD data stripped.
    pv :: [PrimValue]
    pv = map primitive o

    -- Classifies an operand relative to depth d: constants and
    -- variables of a lower depth go Left (treated as constants),
    -- the remaining variables go Right.
    divideDepths :: Int -> ADValue -> Either ADValue ADVariable
    divideDepths _ v@(Constant {}) = Left v
    divideDepths d v@(Variable d' v')
      | d' < d = Left v
      | otherwise = Right v'

    -- TODO: There may be a more graceful way of
    -- doing this
    extractVJP :: Either ADValue ADVariable -> Either ADValue VJPValue
    extractVJP (Right (VJP v)) = Right v
    extractVJP (Left v) = Left v
    extractVJP _ =
      -- Operands at the maximum depth are assumed to all use the same
      -- AD mode, so a JVP variable should not occur here.
      error "extractVJP"

    -- TODO: There may be a more graceful way of
    -- doing this
    extractJVP :: Either ADValue ADVariable -> Either ADValue JVPValue
    extractJVP (Right (JVP v)) = Right v
    extractJVP (Left v) = Left v
    extractJVP _ =
      -- Operands at the maximum depth are assumed to all use the same
      -- AD mode, so a VJP variable should not occur here.
      error "extractJVP"

    -- In this case, every operand is a constant, and the
    -- mathematical operation can be applied as it would be
    -- otherwise
    constCase :: Maybe ADValue
    constCase =
      Constant <$> case (op, pv) of
        (OpBin op', [x, y]) -> doBinOp op' x y
        (OpCmp op', [x, y]) -> BoolValue <$> doCmpOp op' x y
        (OpUn op', [x]) -> doUnOp op' x
        (OpConv op', [x]) -> doConvOp op' x
        (OpFn fn, _) -> do
          (_, _, f) <- M.lookup fn primFuns
          f pv
        -- Reached when the operand count does not fit the operator.
        _ -> error "doOp: opTypeMatch"

    -- In this case, some values are variables. We therefore
    -- have to perform the necessary steps for AD
    nonconstCase :: Int -> Maybe ADValue
    nonconstCase dep = do
      -- First, we calculate the value for the previous depth
      let oprev = map primal o
      vprev <- doOp op oprev

      -- Then we separate the values of the maximum depth from
      -- those of a lower depth
      let o' = map (divideDepths dep) o
      -- Then we find out what type of AD is being performed
      case find isRight o' of
        -- Finally, we perform the necessary steps for the given
        -- type of AD
        Just (Right (VJP {})) ->
          Just . Variable dep . VJP . VJPValue $
            vjpHandleOp op (map extractVJP o') vprev
        Just (Right (JVP {})) ->
          Variable dep . JVP . JVPValue vprev
            <$> jvpHandleFn op (map extractJVP o')
        _ ->
          -- Since the maximum depth is non-zero, there must be at
          -- least one variable of depth > 0
          error "find isRight"

-- | Evaluates the partial derivative of an operation with respect to
-- each of its operands, at the operand values @p@.
--
-- Calls 'error' if no derivative is known for the operation (see
-- 'lookupPDs') or if evaluating a derivative expression fails.
calculatePDs :: Op -> [ADValue] -> [ADValue]
calculatePDs op p =
  -- Create a unique VName for each operand
  let n = [VName (nameFromString $ "x" ++ show i) i | i <- [1 .. length p]]
      -- Put the operands in the environment
      m = M.fromList $ zip n p

      -- Look up, and calculate the partial derivative
      -- of the operation with respect to each operand.
      --
      -- NOTE(review): every leaf is given the operation's return type
      -- rather than the operand's own type; these coincide when
      -- operands and result share a type - confirm for builtins whose
      -- parameter types differ from the result type.
      pde =
        fromMaybe (error "lookupPDs failed") $
          lookupPDs op $
            map (`LeafExp` opReturnType op) n
   in map (fromMaybe (error "evalPrimExp failed") . evalPrimExp m) pde

-- VJP / Reverse mode automatic differentiation--

-- | In reverse mode AD, the entire computation leading up to a
-- variable must be saved.  This is represented as a 'Tape'.
newtype VJPValue = VJPValue Tape
  deriving (Show)

-- | Represents a computation tree, as well as every intermediate
-- value in its evaluation. TODO: make this a graph.
data Tape
  = -- | This represents a variable. Each variable is given a unique ID,
    -- and has an initial value
    TapeID Int ADValue
  | -- | This represents a constant.
    TapeConst ADValue
  | -- | This represents the application of a mathematical operation.
    -- Each parameter is given by its Tape, and the return value of
    -- the operation is saved
    TapeOp Op [Tape] ADValue
  deriving (Show)

-- | Returns the primal value of a 'Tape', i.e. the value stored in
-- its root node.
tapePrimal :: Tape -> ADValue
tapePrimal (TapeID _ v) = v
tapePrimal (TapeConst v) = v
tapePrimal (TapeOp _ _ v) = v

-- | Builds the 'Tape' node for applying an operation to the given
-- operands, treating all operands of a lower depth ('Left') as
-- constants.  The last argument is the already computed result of the
-- operation, which is stored in the new node.
vjpHandleOp :: Op -> [Either ADValue VJPValue] -> ADValue -> Tape
vjpHandleOp op p v = TapeOp op (map toTape p) v
  where
    -- Lower-depth operands become constant leaves; VJP operands
    -- contribute their existing tape.
    toTape :: Either ADValue VJPValue -> Tape
    toTape (Left v') = TapeConst v'
    toTape (Right (VJPValue t)) = t

-- | This calculates every partial derivative of a 'Tape'. The result
-- is a map of the partial derivatives, each key corresponding to the
-- ID of a free variable (see 'TapeID').
--
-- The second argument is the sensitivity flowing into the root of the
-- tape.  Calls 'error' if any of the underlying 'doOp' applications
-- fails.
deriveTape :: Tape -> ADValue -> M.Map Int ADValue
deriveTape (TapeID i _) s = M.singleton i s
deriveTape (TapeConst _) _ = M.empty
deriveTape (TapeOp op p _) s =
  -- Calculate the new sensitivities
  let s'' = case op of
        OpConv op' ->
          -- In case of type conversion, simply convert the sensitivity
          [ fromMaybe (error "deriveTape: doOp failed") $
              doOp (OpConv $ flipConvOp op') [s]
          ]
        _ ->
          -- Otherwise scale the incoming sensitivity by each partial
          -- derivative, evaluated at the operands' primal values
          map (mul s) $ calculatePDs op $ map tapePrimal p

      -- Propagate the new sensitivities
      pd = zipWith deriveTape p s''
   in -- Add up the results
      foldl' (M.unionWith add) M.empty pd
  where
    -- Addition and multiplication at the operation's return type.
    add :: ADValue -> ADValue -> ADValue
    add x y =
      fromMaybe (error "deriveTape: add failed") $
        doOp (OpBin $ addFor $ opReturnType op) [x, y]
    mul :: ADValue -> ADValue -> ADValue
    mul x y =
      fromMaybe (error "deriveTape: mul failed") $
        doOp (OpBin $ mulFor $ opReturnType op) [x, y]

-- JVP / Forward mode automatic differentiation--

-- | In JVP, the derivative of the variable must be saved. This is
-- represented as a second value: the first field is the primal, the
-- second the derivative.
data JVPValue = JVPValue ADValue ADValue
  deriving (Show)

-- | This calculates the derivative part of the JVPValue resulting
-- from the application of a mathematical operation on one or more
-- JVPValues.
--
-- Operands given as 'Left' are treated as constants, i.e. their
-- derivative is the blank value of their type.  Returns 'Nothing' if
-- any underlying 'doOp' application fails, or if a conversion is
-- applied to no operands.
jvpHandleFn :: Op -> [Either ADValue JVPValue] -> Maybe ADValue
jvpHandleFn op p =
  case op of
    OpConv _ ->
      -- In case of type conversion, simply convert
      -- the old derivative.  Matching on the list instead of using
      -- 'head' keeps this total for an empty operand list.
      case p of
        x : _ -> doOp op [derivative x]
        [] -> Nothing
    _ -> do
      -- Calculate the new derivative using the chain
      -- rule
      let pds = calculatePDs op $ map primal' p
      vs <- zipWithM mul pds $ map derivative p
      foldM add (Constant $ blankPrimValue $ opReturnType op) vs
  where
    -- The primal part of an operand.
    primal' :: Either ADValue JVPValue -> ADValue
    primal' (Left v) = v
    primal' (Right (JVPValue v _)) = v

    -- The derivative part of an operand; the blank value of the
    -- operand's type for operands treated as constants.
    derivative :: Either ADValue JVPValue -> ADValue
    derivative (Left v) = Constant $ blankPrimValue $ primValueType $ primitive v
    derivative (Right (JVPValue _ d)) = d

    -- Addition and multiplication at the operation's return type.
    add :: ADValue -> ADValue -> Maybe ADValue
    add x y = doOp (OpBin $ addFor $ opReturnType op) [x, y]
    mul :: ADValue -> ADValue -> Maybe ADValue
    mul x y = doOp (OpBin $ mulFor $ opReturnType op) [x, y]