{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.AD (applyAD, applyADInnermost) where
import Control.Monad
import Control.Monad.Reader
import Futhark.AD.Fwd (fwdJVP)
import Futhark.AD.Rev (revVJP)
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify (simplifyLambda)
import Futhark.Pass
-- | Which AD operators the pass eliminates.
--
-- * 'Innermost': a @VJP@/@JVP@ is differentiated only if recursively
--   processing its lambda leaves that lambda unchanged (presumably: it
--   contains no nested AD operators). Otherwise the operator is kept,
--   wrapped around the processed lambda.
--
-- * 'All': every @VJP@/@JVP@ is differentiated.
data Mode = Innermost | All
  deriving (Eq)
-- | Inline a lambda at the current position of the builder.
--
-- The function performs three steps in order:
--
-- 1. Each lambda parameter is bound to the corresponding argument.
-- 2. The lambda body's statements are bound with 'bodyBind'.
-- 3. Each name in @pat@ is bound to the corresponding body result, under
--    that result's certificates.
--
-- The statement auxiliary information @aux@ is attached only to the
-- parameter bindings (step 1).
--
-- NOTE(review): both pairings truncate silently (as 'zip' does). A length
-- mismatch between parameters and arguments, or between pattern and
-- results, is not reported.
bindLambda ::
  (MonadBuilder m, Rep m ~ SOACS) =>
  Pat Type ->
  StmAux (ExpDec SOACS) ->
  Lambda SOACS ->
  [SubExp] ->
  m ()
bindLambda pat aux (Lambda params _ body) args = do
  auxing aux $ zipWithM_ bindParam params args
  res <- bodyBind body
  zipWithM_ bindResult (patNames pat) res
  where
    -- Bind one lambda parameter to its argument.
    bindParam param arg =
      letBindNames [paramName param] . BasicOp $
        case paramType param of
          -- Array parameters are bound through a 'Replicate' with an empty
          -- shape instead of a plain alias. Presumably this gives the body
          -- a fresh array it may consume. TODO confirm.
          Array {} -> Replicate mempty arg
          _ -> SubExp arg

    -- Bind one pattern name to a body result, keeping its certificates.
    bindResult v (SubExpRes cs se) =
      certifying cs . letBindNames [v] . BasicOp $ SubExp se
-- | Process a single statement, returning the statements it expands to.
--
-- * A @VJP args vec lam@ statement: the lambda is transformed with
--   'revVJP' (from "Futhark.AD.Rev"), simplified with 'simplifyLambda',
--   and inlined via 'bindLambda' applied to @args ++ vec@.
--
-- * A @JVP args vec lam@ statement is handled the same way, using 'fwdJVP'
--   (from "Futhark.AD.Fwd") and no simplification step.
--
-- * In both cases @lam@ is first processed recursively. Under 'Innermost',
--   the operator is eliminated only if that left the lambda unchanged.
--   Otherwise the operator is rebuilt around the processed lambda.
--
-- * Any other statement is kept. The pass is applied to the bodies and SOAC
--   lambdas nested in its expression.
onStm :: Mode -> Scope SOACS -> Stm SOACS -> PassM (Stms SOACS)
onStm mode scope (Let pat aux e) =
  case e of
    Op (VJP args vec lam) ->
      differentiate (Op . VJP args vec) (args ++ vec) lam $ \lam' ->
        (`runReaderT` scope) . simplifyLambda =<< revVJP scope lam'
    Op (JVP args vec lam) ->
      differentiate (Op . JVP args vec) (args ++ vec) lam $ fwdJVP scope
    _ -> oneStm . Let pat aux <$> mapExpM mapper e
  where
    -- Shared handling of VJP and JVP. The arguments are:
    --
    -- * @rebuild@ reconstructs the original operator around a processed
    --   lambda.
    -- * @lam_args@ are the values the transformed lambda is applied to.
    -- * @transform@ produces the transformed lambda.
    differentiate ::
      (Lambda SOACS -> Exp SOACS) ->
      [SubExp] ->
      Lambda SOACS ->
      (Lambda SOACS -> PassM (Lambda SOACS)) ->
      PassM (Stms SOACS)
    differentiate rebuild lam_args lam transform = do
      lam' <- onLambda mode scope lam
      if mode == All || lam == lam'
        then do
          lam'' <- transform lam'
          runBuilderT_ (bindLambda pat aux lam'' lam_args) scope
        else pure . oneStm . Let pat aux $ rebuild lam'

    mapper :: Mapper SOACS SOACS PassM
    mapper =
      (identityMapper @SOACS)
        { mapOnBody = \bscope -> onBody mode (bscope <> scope),
          mapOnOp = mapSOACM soac_mapper
        }

    soac_mapper :: SOACMapper SOACS SOACS PassM
    soac_mapper = identitySOACMapper {mapOnSOACLambda = onLambda mode scope}
-- | Process a sequence of statements with 'onStm' and concatenate the
-- results.
--
-- Each statement is processed with the incoming scope extended by the
-- names bound anywhere in the sequence, not only by the preceding
-- statements.
onStms :: Mode -> Scope SOACS -> Stms SOACS -> PassM (Stms SOACS)
onStms mode scope stms =
  mconcat <$> traverse (onStm mode stms_scope) (stmsToList stms)
  where
    stms_scope = scopeOf stms <> scope
-- | Process the statements of a body. The body's result is left as is.
onBody :: Mode -> Scope SOACS -> Body SOACS -> PassM (Body SOACS)
onBody mode scope body =
  withStms <$> onStms mode scope (bodyStms body)
  where
    withStms stms = body {bodyStms = stms}
-- | Process the body of a lambda. The lambda's parameters are added to the
-- scope.
onLambda :: Mode -> Scope SOACS -> Lambda SOACS -> PassM (Lambda SOACS)
onLambda mode scope lam =
  withBody <$> onBody mode lam_scope (lambdaBody lam)
  where
    lam_scope = scopeOfLParams (lambdaParams lam) <> scope
    withBody body = lam {lambdaBody = body}
-- | Process the body of a function definition.
--
-- The scope combines two parts:
--
-- * the scope of the given top-level constants;
-- * the scope of the definition itself (presumably its parameters).
onFun :: Mode -> Stms SOACS -> FunDef SOACS -> PassM (FunDef SOACS)
onFun mode consts fd =
  withBody <$> onBody mode fun_scope (funDefBody fd)
  where
    fun_scope = scopeOf consts <> scopeOf fd
    withBody body = fd {funDefBody = body}
-- | Eliminate all AD operators (@VJP@ and @JVP@) by applying the pass in
-- 'All' mode.
--
-- The top-level constants are processed with an empty initial scope.
-- Function definitions are processed with 'onFun'.
applyAD :: Pass SOACS SOACS
applyAD =
  Pass
    { passName = "ad",
      passDescription = "Apply AD operators",
      passFunction =
        intraproceduralTransformationWithConsts
          (onStms All mempty)
          (onFun All)
    }
-- | Eliminate only the innermost AD operators by applying the pass in
-- 'Innermost' mode.
--
-- Operators whose lambdas still contained nested AD operators are kept,
-- wrapped around the processed lambdas. The constants and functions are
-- handled as in 'applyAD'.
applyADInnermost :: Pass SOACS SOACS
applyADInnermost =
  Pass
    { passName = "ad innermost",
      passDescription = "Apply innermost AD operators",
      passFunction =
        intraproceduralTransformationWithConsts
          (onStms Innermost mempty)
          (onFun Innermost)
    }