{-# LANGUAGE FlexibleContexts #-}

module Language.Haskell.Liquid.Transforms.InlineAux
  ( inlineAux
  )
where
import qualified Language.Haskell.Liquid.UX.Config  as UX
import           Liquid.GHC.API
import           Control.Arrow                  (second)
import qualified Language.Haskell.Liquid.GHC.Misc
                                               as GM
import qualified Data.HashMap.Strict           as M

-- | When the 'UX.auxInline' option is enabled, rewrite the right-hand sides of
-- instance-method auxiliary bindings with 'inlineAuxExpr', then run the whole
-- program through 'occurAnalysePgm'. When the option is disabled the program
-- is returned unchanged.
--
-- The auxiliary bindings are found by collecting the dictionary functions of
-- the program ('grepDFunIds') and, for each one, the auxiliary ids occurring
-- free in its right-hand side ('dfunIdSubst'). A binding is rewritten only if
-- its binder is one of those ids; all other bindings are kept as they are.
inlineAux :: UX.Config -> Module -> CoreProgram -> CoreProgram
inlineAux cfg m cbs
  | UX.auxInline cfg =
      -- NOTE(review): presumably re-run to refresh occurrence info after the
      -- rewrite; the predicates are constantly False and no rules are passed.
      occurAnalysePgm m (const False) (const False) [] (map f cbs)
  | otherwise = cbs
 where
  -- Rewrite every binding of a (possibly recursive) binding group.
  f :: CoreBind -> CoreBind
  f (NonRec x e) = uncurry NonRec (inlineInPair (x, e))
  f (Rec bs)     = Rec (map inlineInPair bs)

  -- Shared by both binding forms: if the binder is a known auxiliary id,
  -- inline method calls on its dictionary function inside the right-hand
  -- side; otherwise leave the pair untouched.
  inlineInPair :: (Id, CoreExpr) -> (Id, CoreExpr)
  inlineInPair pair@(x, e)
    | Just (dfunId, methodToAux) <- M.lookup x auxToMethodToAux
    = (x, inlineAuxExpr dfunId methodToAux e)
    | otherwise
    = pair

  -- Auxiliary id |-> (its dictionary function, selector |-> auxiliary id),
  -- merged over all dictionary functions of the program (left-biased union).
  auxToMethodToAux :: M.HashMap Id (Id, M.HashMap Id Id)
  auxToMethodToAux = mconcat $ fmap (uncurry dfunIdSubst) (grepDFunIds cbs)


-- inlineDFun :: DynFlags -> CoreProgram -> IO CoreProgram
-- inlineDFun df cbs = mapM go cbs
--  where
--   go orig@(NonRec x e) | isDFunId x = do
--                            -- e''' <- simplifyExpr df e''
--                            let newBody = mkCoreApps (GM.tracePpr ("substituted type:" ++ GM.showPpr (exprType (mkCoreApps e' (Var <$> binders)))) e') (fmap Var binders)
--                                bind = NonRec (mkWildValBinder (exprType newBody)) newBody
--                            pure $ NonRec x (mkLet bind e)
--                        | otherwise  = pure orig
--    where
--     -- wcBinder = mkWildValBinder t
--     (binders, _) = GM.tracePpr "collectBinders"$ collectBinders e
--     e' = substExprAll empty subst e
--   go recs = pure recs
--   subst = buildDictSubst cbs

-- | Grab the dictionaries: every binding in the program's binding list
-- (recursive groups flattened, nested lets not searched) whose binder is a
-- dictionary-function id, paired with its right-hand side.
grepDFunIds :: CoreProgram -> [(DFunId, CoreExpr)]
grepDFunIds cbs =
  [ binding | binding@(binder, _) <- flattenBinds cbs, isDFunId binder ]

-- | Does this occurrence name start with the @$c@ prefix, which this module
-- uses to recognise instance-method auxiliary ids?
isClassOpAuxOccName :: OccName -> Bool
isClassOpAuxOccName = hasAuxPrefix . occNameString
  where
    hasAuxPrefix ('$' : 'c' : _) = True
    hasAuxPrefix _               = False

-- | @aux `isClassOpAuxOf` method@ holds exactly when the occurrence name of
-- @aux@ is @$c@ followed by the occurrence name of @method@.
isClassOpAuxOf :: Id -> Id -> Bool
isClassOpAuxOf aux method = nameOf aux == '$' : 'c' : nameOf method
  where
    nameOf = occNameString . getOccName

-- | Given a dictionary function and its right-hand side, map every auxiliary
-- id (occurrence name starting with @$c@) occurring free in that right-hand
-- side to @(dfunId, methodToAux)@.
--
-- @methodToAux@ maps each selector of the instance's class (obtained from
-- the dictionary function's type via 'tcSplitDFunTy') to the free auxiliary
-- id named @$c@ followed by the selector's name. If several candidates
-- match one selector, the last one in 'exprFreeVarsList' order is kept
-- ('M.fromList' semantics).
dfunIdSubst :: DFunId -> CoreExpr -> M.HashMap Id (Id, M.HashMap Id Id)
dfunIdSubst dfunId e = M.fromList (zip auxIds (repeat entry))
 where
  entry = (dfunId, methodToAux)

  -- The class whose instance this dictionary function constructs.
  cls = case tcSplitDFunTy (idType dfunId) of
    (_, _, c, _) -> c

  auxIds :: [Id]
  auxIds =
    [ v | v <- exprFreeVarsList e, isClassOpAuxOccName (getOccName v) ]

  methodToAux :: M.HashMap Id Id
  methodToAux = M.fromList $
    filter (\(m, aux) -> aux `isClassOpAuxOf` m)
      [ (m, aux) | m <- classAllSelIds cls, aux <- auxIds ]

-- | Redirect class-method calls on the dictionary function @dfunId@ to the
-- corresponding auxiliary ids.
--
-- A call shaped like  sel tys (dfunId dargs) args,  where @sel@ maps to
-- @aux@ in @methodToAux@, becomes  aux dargs args:  the selector's leading
-- type arguments are dropped, the dictionary-function arguments are passed
-- unchanged, and the remaining arguments are rewritten recursively.
--
-- A non-recursive @let@ that binds a dictionary id is eliminated by
-- substituting its right-hand side into the body, so that locally bound
-- dictionaries can be matched as well. Every other construct is traversed
-- structurally; variables, literals, types and coercions are returned as is.
inlineAuxExpr :: DFunId -> M.HashMap Id Id -> CoreExpr -> CoreExpr
inlineAuxExpr dfunId methodToAux = rewrite
 where
  rewrite :: CoreExpr -> CoreExpr
  rewrite expr = case expr of
    Lam bndr body           -> Lam bndr (rewrite body)
    Let bnd body            -> rewriteLet bnd body
    Case scrut bndr ty alts -> Case (rewrite scrut) bndr ty (map (mapAlt rewrite) alts)
    Cast inner co           -> Cast (rewrite inner) co
    Tick tickish inner      -> Tick tickish (rewrite inner)
    -- Tried before the generic App case; if it does not apply, matching
    -- falls through to the alternatives below.
    _ | Just redirected <- redirectMethodCall expr -> redirected
    App fun arg             -> App (rewrite fun) (rewrite arg)
    _                       -> expr

  rewriteLet :: CoreBind -> CoreExpr -> CoreExpr
  rewriteLet bnd body = case bnd of
    -- NOTE(review): the substitution starts from 'emptySubst', so its
    -- in-scope set does not contain the free variables of @rhs@; GHC's
    -- substitution invariant expects it to. Confirm that binders in @body@
    -- cannot capture them.
    NonRec x rhs | isDictId x ->
      rewrite (substExpr (extendIdSubst emptySubst x rhs) body)
    _ -> Let (mapBnd rewrite bnd) (rewrite body)

  redirectMethodCall :: CoreExpr -> Maybe CoreExpr
  redirectMethodCall expr
    | (Var sel, args)      <- collectArgs expr
    , Just aux             <- M.lookup sel methodToAux
    , dict : valArgs       <- dropWhile isTypeArg args
    , (Var d, dictArgs)    <- collectArgs dict
    , d == dfunId
    = Just
        $ GM.notracePpr ("inlining in" ++ GM.showPpr expr)
        $ mkCoreApps (Var aux) (dictArgs ++ map rewrite valArgs)
    | otherwise
    = Nothing


-- | Apply a transformation to every right-hand side of a binding group,
-- keeping the binders unchanged. (Modified from Rec.hs.)
mapBnd :: (Expr b -> Expr b) -> Bind b -> Bind b
mapBnd f bnd = case bnd of
  NonRec b e -> NonRec b (f e)
  Rec pairs  -> Rec (second f <$> pairs)

-- | Apply a transformation to the right-hand side of a case alternative,
-- keeping its constructor and pattern binders unchanged.
mapAlt :: (Expr b -> Expr b) -> Alt b -> Alt b
mapAlt f alt = case alt of
  Alt con binders rhs -> Alt con binders (f rhs)