module Language.Futhark.Interpreter.AD
( Op (..),
ADVariable (..),
ADValue (..),
Tape (..),
VJPValue (..),
JVPValue (..),
doOp,
addFor,
tapePrimal,
primitive,
varPrimal,
deriveTape,
)
where
import Control.Monad (foldM, zipWithM)
import Data.Either (isRight)
import Data.List (find, foldl')
import Data.Map qualified as M
import Data.Maybe (fromMaybe)
import Data.Text qualified as T
import Futhark.AD.Derivatives (pdBinOp, pdBuiltin, pdUnOp)
import Futhark.Analysis.PrimExp (PrimExp (..))
import Language.Futhark.Core (VName (..), nameFromString, nameFromText)
import Language.Futhark.Primitive
-- | A primitive operation that the AD interpreter can apply to
-- (possibly AD-tracked) values.
data Op
  = -- | A binary operator.
    OpBin BinOp
  | -- | A comparison operator.
    OpCmp CmpOp
  | -- | A unary operator.
    OpUn UnOp
  | -- | A builtin function, identified by its name in 'primFuns'.
    OpFn T.Text
  | -- | A conversion between primitive types.
    OpConv ConvOp
  deriving (Show)
-- | Check whether an operator accepts arguments of exactly the given
-- types.  Both the number of arguments and each argument's type must
-- match the operator's signature, so that 'doOp' can reject ill-formed
-- applications with 'Nothing' instead of failing later.
--
-- Calls 'error' if an 'OpFn' names a function absent from 'primFuns'.
opTypeMatch :: Op -> [PrimType] -> Bool
opTypeMatch (OpBin op) p = p == [binOpType op, binOpType op]
opTypeMatch (OpCmp op) p = p == [cmpOpType op, cmpOpType op]
opTypeMatch (OpUn op) p = p == [unOpType op]
opTypeMatch (OpConv op) p = p == [fst (convOpType op)]
opTypeMatch (OpFn fn) p = case M.lookup fn primFuns of
  -- Whole-list equality also checks the arity; a bare 'zipWith' would
  -- silently accept missing or surplus arguments.
  Just (t, _, _) -> t == p
  Nothing -> error "opTypeMatch"
-- | The type of the value an operator produces.
--
-- Calls 'error' if an 'OpFn' names a function absent from 'primFuns'.
opReturnType :: Op -> PrimType
opReturnType (OpBin op) = binOpType op
-- 'cmpOpType' is the type of the compared operands; the comparison
-- itself always yields a boolean.
opReturnType (OpCmp _) = Bool
opReturnType (OpUn op) = unOpType op
opReturnType (OpConv op) = snd (convOpType op)
opReturnType (OpFn fn) = case M.lookup fn primFuns of
  Just (_, t, _) -> t
  Nothing -> error "opReturnType"
-- | The addition operator for a primitive type, used to accumulate
-- derivatives: wrapping addition for integers, floating-point addition
-- for floats, and logical or for booleans.  Calls 'error' for any other
-- type.
addFor :: PrimType -> BinOp
addFor pt = case pt of
  IntType it -> Add it OverflowWrap
  FloatType ft -> FAdd ft
  Bool -> LogOr
  _ -> error $ "addFor: " ++ show pt
-- | The multiplication operator for a primitive type, used to combine
-- derivatives: wrapping multiplication for integers, floating-point
-- multiplication for floats, and logical and for booleans.  Calls
-- 'error' for any other type.
mulFor :: PrimType -> BinOp
mulFor pt = case pt of
  IntType it -> Mul it OverflowWrap
  FloatType ft -> FMul ft
  Bool -> LogAnd
  _ -> error $ "mulFor: " ++ show pt
-- | A value manipulated by the AD interpreter.
data ADValue
  = -- | A variable tracked at the given nesting depth.
    Variable Int ADVariable
  | -- | An untracked primitive value (depth 0).
    Constant PrimValue
  deriving (Show)
-- | The tracking information carried by a 'Variable'.
data ADVariable
  = -- | Reverse-mode tracking (a tape).
    VJP VJPValue
  | -- | Forward-mode tracking (primal and tangent).
    JVP JVPValue
  deriving (Show)
-- | The AD nesting depth of a value; constants are at depth 0.
depth :: ADValue -> Int
depth v = case v of
  Variable d _ -> d
  Constant _ -> 0
-- | Strip exactly one level of AD tracking from a value.
--
-- * For a reverse-mode variable this is the primal recorded on its tape.
-- * For a forward-mode variable it is the stored primal, which may itself
--   still be tracked at a lower depth.
-- * Constants are returned unchanged.
--
-- The forward-mode case deliberately does not recurse.  Stripping every
-- level would throw away the tracking that enclosing (lower-depth) AD
-- levels depend on, and would disagree with the reverse-mode case.
primal :: ADValue -> ADValue
primal (Variable _ (VJP (VJPValue t))) = tapePrimal t
primal (Variable _ (JVP (JVPValue v _))) = v
primal c@(Constant _) = c
-- | The underlying primitive value, with all levels of AD tracking
-- removed.
primitive :: ADValue -> PrimValue
primitive (Constant pv) = pv
primitive (Variable _ var) = varPrimal var
-- | The underlying primitive value of an AD variable, with all levels of
-- tracking removed.  'primitive' already recurses through every level,
-- so the stored primal can be handed to it directly.
varPrimal :: ADVariable -> PrimValue
varPrimal (VJP (VJPValue t)) = primitive (tapePrimal t)
varPrimal (JVP (JVPValue v _)) = primitive v
-- | Evaluate a 'PrimExp' over AD values.
--
-- Leaves are looked up in the given environment.  Every operator is
-- applied through 'doOp', so AD tracking propagates through the
-- expression.  Arguments are evaluated left to right.
--
-- Returns 'Nothing' if a leaf is unbound or an operator rejects its
-- arguments.
evalPrimExp :: M.Map VName ADValue -> PrimExp VName -> Maybe ADValue
evalPrimExp m e = case e of
  LeafExp n _ -> M.lookup n m
  ValueExp pv -> Just (Constant pv)
  BinOpExp op x y -> apply (OpBin op) [x, y]
  CmpOpExp op x y -> apply (OpCmp op) [x, y]
  UnOpExp op x -> apply (OpUn op) [x]
  ConvOpExp op x -> apply (OpConv op) [x]
  FunExp fn args _ -> apply (OpFn fn) args
  where
    apply :: Op -> [PrimExp VName] -> Maybe ADValue
    apply op args = mapM (evalPrimExp m) args >>= doOp op
-- | Symbolic partial derivatives of an operator with respect to each of
-- its arguments, built with 'pdBinOp', 'pdUnOp' and 'pdBuiltin'.
--
-- Returns 'Nothing' in any of these cases:
--
-- * comparisons and conversions;
-- * an argument count that does not fit a binary or unary operator;
-- * a builtin for which 'pdBuiltin' gives no rule.
lookupPDs :: Op -> [PrimExp VName] -> Maybe [PrimExp VName]
lookupPDs op args = case (op, args) of
  (OpBin bop, [x, y]) ->
    let (dx, dy) = pdBinOp bop x y
     in Just [dx, dy]
  (OpUn uop, [x]) -> Just [pdUnOp uop x]
  (OpFn fn, _) -> pdBuiltin (nameFromText fn) args
  _ -> Nothing
-- | Apply an operator to a list of AD values.
--
-- Returns 'Nothing' if the operator rejects the argument types (see
-- 'opTypeMatch').
--
-- If the operator is a comparison, or every argument is at depth 0, the
-- result is a plain 'Constant'.
--
-- Otherwise the result is a 'Variable' at the maximum argument depth:
--
-- * its primal comes from recursively applying the operator one level
--   further down;
-- * its tape (reverse mode) or tangent (forward mode) is built from the
--   arguments living at that depth.
doOp :: Op -> [ADValue] -> Maybe ADValue
doOp op o
  | not $ opTypeMatch op (map primValueType pv) = Nothing
  | dep == 0 = constCase
  | otherwise = nonconstCase dep
  where
    pv :: [PrimValue]
    pv = map primitive o

    -- Comparisons are never differentiated, so their result is always
    -- constant.  An empty argument list is treated as depth 0 instead of
    -- crashing in the partial 'maximum'.
    dep :: Int
    dep = case op of
      OpCmp _ -> 0
      _
        | null o -> 0
        | otherwise -> maximum (map depth o)

    -- Arguments below depth d (and constants) are 'Left': they are
    -- treated as constants at this level.  Arguments at depth d or
    -- deeper expose their tracking information as 'Right'.
    divideDepths :: Int -> ADValue -> Either ADValue ADVariable
    divideDepths _ v@(Constant {}) = Left v
    divideDepths d v@(Variable d' v') = if d' < d then Left v else Right v'

    extractVJP :: Either ADValue ADVariable -> Either ADValue VJPValue
    extractVJP (Right (VJP v)) = Right v
    extractVJP (Left v) = Left v
    extractVJP _ = error "extractVJP"

    extractJVP :: Either ADValue ADVariable -> Either ADValue JVPValue
    extractJVP (Right (JVP v)) = Right v
    extractJVP (Left v) = Left v
    extractJVP _ = error "extractJVP"

    -- Plain evaluation on the primitive values.
    constCase :: Maybe ADValue
    constCase =
      Constant <$> case (op, pv) of
        (OpBin op', [x, y]) -> doBinOp op' x y
        (OpCmp op', [x, y]) -> BoolValue <$> doCmpOp op' x y
        (OpUn op', [x]) -> doUnOp op' x
        (OpConv op', [x]) -> doConvOp op' x
        (OpFn fn, _) -> do
          (_, _, f) <- M.lookup fn primFuns
          f pv
        _ -> error "doOp: opTypeMatch"

    nonconstCase :: Int -> Maybe ADValue
    nonconstCase d = do
      let o' = map (divideDepths d) o
          -- Strip tracking only from the arguments living at depth d.
          -- Shallower arguments are passed down unchanged, so an
          -- enclosing (lower-depth) AD level still sees how the primal
          -- result depends on them.
          oprev = [if depth v < d then v else primal v | v <- o]
      vprev <- doOp op oprev
      case find isRight o' of
        Just (Right (VJP {})) ->
          Just . Variable d . VJP . VJPValue $
            vjpHandleOp op (map extractVJP o') vprev
        Just (Right (JVP {})) ->
          Variable d . JVP . JVPValue vprev
            <$> jvpHandleFn op (map extractJVP o')
        _ -> error "find isRight"
-- | Evaluate the partial derivatives of an operator at the given argument
-- values, one result per argument.
--
-- The symbolic derivatives from 'lookupPDs' are instantiated with leaf
-- names @x1@, @x2@, ... that are bound to the arguments.  They are then
-- evaluated with 'evalPrimExp', so the results are themselves AD values.
--
-- Each leaf is annotated with the type of the argument it stands for.
-- The operator's result type would be wrong for builtins whose parameter
-- types differ from their result type.
--
-- Calls 'error' if the operator has no derivative rule or evaluation
-- fails.
calculatePDs :: Op -> [ADValue] -> [ADValue]
calculatePDs op args =
  map (fromMaybe (error "evalPrimExp failed") . evalPrimExp env) pdes
  where
    names :: [VName]
    names =
      [VName (nameFromString $ "x" ++ show i) i | i <- [1 .. length args]]

    env :: M.Map VName ADValue
    env = M.fromList (zip names args)

    leaves :: [PrimExp VName]
    leaves =
      zipWith (\n a -> LeafExp n (primValueType (primitive a))) names args

    pdes :: [PrimExp VName]
    pdes = fromMaybe (error "lookupPDs failed") (lookupPDs op leaves)
-- | Reverse-mode tracking information: the tape recording how the value
-- was computed.
newtype VJPValue = VJPValue Tape
  deriving (Show)
-- | A record of how a reverse-mode value was computed.  Every node stores
-- its primal value (see 'tapePrimal').
data Tape
  = -- | A tracked input identified by an integer ID.  'deriveTape' reports
    -- the adjoint reaching this node under that ID.
    TapeID Int ADValue
  | -- | A value treated as constant at this depth.  It contributes no
    -- adjoints.
    TapeConst ADValue
  | -- | The result of applying an operator to the values of the sub-tapes.
    TapeOp Op [Tape] ADValue
  deriving (Show)
-- | The primal value stored at the root of a tape.
tapePrimal :: Tape -> ADValue
tapePrimal t = case t of
  TapeID _ v -> v
  TapeConst v -> v
  TapeOp _ _ v -> v
-- | Build the tape node for an operator application in reverse mode.
--
-- * 'Left' arguments (lower depth or constant) become 'TapeConst' leaves.
-- * 'Right' arguments contribute their existing tapes.
--
-- The given value is recorded as the node's primal.
vjpHandleOp :: Op -> [Either ADValue VJPValue] -> ADValue -> Tape
vjpHandleOp op p v = TapeOp op (map (either TapeConst unwrap) p) v
  where
    unwrap :: VJPValue -> Tape
    unwrap (VJPValue t) = t
-- | Propagate an adjoint (seed) @s@ backwards through a tape.
--
-- Returns the adjoint accumulated for every 'TapeID' reached.  Adjoints
-- that reach the same ID along several paths are summed.  'TapeConst'
-- nodes contribute nothing.
--
-- Calls 'error' if applying the conversion, sum or product fails.
deriveTape :: Tape -> ADValue -> M.Map Int ADValue
deriveTape (TapeID i _) s = M.singleton i s
deriveTape (TapeConst _) _ = M.empty
deriveTape (TapeOp op p _) s =
  foldl' (M.unionWith add) M.empty (zipWith deriveTape p s'')
  where
    -- Adjoints flowing into each argument of the operator.
    s'' :: [ADValue]
    s'' = case op of
      -- A conversion's adjoint is the output adjoint converted back.
      OpConv op' ->
        [ fromMaybe (error "deriveTape: doOp failed") $
            doOp (OpConv $ flipConvOp op') [s]
        ]
      _ -> map (mul s) $ calculatePDs op $ map tapePrimal p

    -- The summed adjoints belong to the same input variable, whose type
    -- need not be this operator's result type (e.g. beneath a
    -- conversion).  Pick the addition from the summands themselves.
    add :: ADValue -> ADValue -> ADValue
    add x y =
      fromMaybe (error "deriveTape: add failed") $
        doOp (OpBin $ addFor $ primValueType $ primitive x) [x, y]

    mul :: ADValue -> ADValue -> ADValue
    mul x y =
      fromMaybe (error "deriveTape: mul failed") $
        doOp (OpBin $ mulFor $ opReturnType op) [x, y]
-- | Forward-mode tracking information: the primal value (first field)
-- and its tangent (second field).
data JVPValue = JVPValue ADValue ADValue
  deriving (Show)
-- | Compute the tangent of an operator application in forward mode.
--
-- Arguments are either 'Left' or 'Right':
--
-- * 'Left' arguments (lower depth or constant) get a blank tangent.
-- * 'Right' arguments carry their own primal and tangent.
--
-- For a conversion, the tangent is the argument's tangent passed through
-- the same conversion.  Otherwise it is the sum, over all arguments, of
-- each partial derivative times the corresponding argument tangent.
--
-- Returns 'Nothing' if an intermediate 'doOp' fails, or if a conversion
-- receives no argument.
jvpHandleFn :: Op -> [Either ADValue JVPValue] -> Maybe ADValue
jvpHandleFn op p =
  case op of
    OpConv _ -> case p of
      -- Pattern match instead of the partial 'head'.
      x : _ -> doOp op [derivative x]
      [] -> Nothing
    _ -> do
      let pds = calculatePDs op $ map primal' p
      vs <- zipWithM mul pds $ map derivative p
      foldM add (Constant $ blankPrimValue $ opReturnType op) vs
  where
    primal' :: Either ADValue JVPValue -> ADValue
    primal' (Left v) = v
    primal' (Right (JVPValue v _)) = v

    -- A 'Left' argument is constant at this depth.  Its tangent is the
    -- blank value of its type; this assumes 'blankPrimValue' is zero.
    derivative :: Either ADValue JVPValue -> ADValue
    derivative (Left v) =
      Constant $ blankPrimValue $ primValueType $ primitive v
    derivative (Right (JVPValue _ d)) = d

    add :: ADValue -> ADValue -> Maybe ADValue
    add x y = doOp (OpBin $ addFor $ opReturnType op) [x, y]

    mul :: ADValue -> ADValue -> Maybe ADValue
    mul x y = doOp (OpBin $ mulFor $ opReturnType op) [x, y]