{-# LANGUAGE OverloadedStrings         #-}
{-# LANGUAGE FlexibleInstances         #-}
{-# LANGUAGE GADTs                     #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE RankNTypes                #-}
{-# LANGUAGE TupleSections             #-}
{-# LANGUAGE UndecidableInstances      #-}
{-# LANGUAGE FlexibleContexts          #-}

{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}

-- | This module contains functions for recursively "rewriting"
--   GHC core using "rules".

module Language.Haskell.Liquid.Transforms.Rewrite
  ( -- * Top level rewrite function
    rewriteBinds

  -- * Low-level Rewriting Function
  -- , rewriteWith

  -- * Rewrite Rule
  -- ,  RewriteRule

  ) where

import           Liquid.GHC.API as Ghc hiding (get, showPpr, substExpr)
import           Language.Haskell.Liquid.GHC.TypeRep ()
import           Data.Maybe     (fromMaybe, isJust, mapMaybe)
import           Control.Monad.State hiding (lift)
import           Language.Haskell.Liquid.Misc (Nat)
import           Language.Haskell.Liquid.GHC.Play (sub, substExpr)
import           Language.Haskell.Liquid.GHC.Misc (unTickExpr, isTupleId, mkAlive)
import           Language.Haskell.Liquid.Types.Errors (impossible)
import           Language.Haskell.Liquid.UX.Config  (Config, noSimplifyCore)
import qualified Data.List as L
import qualified Data.HashMap.Strict as M

--------------------------------------------------------------------------------
-- | Top-level rewriter --------------------------------------------------------
--------------------------------------------------------------------------------
-- | Apply the Core clean-up passes of this module to every top-level binding.
--
-- The passes are composed with @(.)@, so they run right-to-left:
-- 'inlineFailCases' runs first and 'normalizeTuples' runs last. When
-- 'simplifyCore' is off (the @noSimplifyCore@ flag of the 'Config' is set),
-- the bindings are returned unchanged.
rewriteBinds :: Config -> [CoreBind] -> [CoreBind]
rewriteBinds cfg
  | simplifyCore cfg = map pipeline
  | otherwise        = id
  where
    pipeline :: CoreBind -> CoreBind
    pipeline = normalizeTuples
             . rewriteBindWith undollar
             . tidyTuples
             . rewriteBindWith inlineLoopBreakerTx
             . inlineLoopBreaker
             . rewriteBindWith strictifyLazyLets
             . inlineFailCases

-- | Whether Core simplification is enabled, i.e. the negation of the
-- @noSimplifyCore@ flag of the 'Config'.
simplifyCore :: Config -> Bool
simplifyCore cfg = not (noSimplifyCore cfg)

-- | Rewrite rule that turns the pair produced by 'splitDollarApp'
-- (presumably the function and argument of an @f $ a@ application) into the
-- direct application @App f a@. It does not fire when 'splitDollarApp'
-- returns @Nothing@.
undollar :: RewriteRule
undollar e = uncurry App <$> splitDollarApp e

-- | Make repeated @case@s on the same variable share their pattern binders.
--
-- The traversal threads a table mapping @(constructor, scrutinee variable)@
-- to the binders chosen for that pair. The first alternative seen for a pair
-- keeps its binders, marks them with 'mkAlive' and records them. Every later
-- alternative for the same pair is rebuilt with the recorded binders, and its
-- body is renamed with 'substTuple'. Each right-hand side starts with an
-- empty table.
--
-- NOTE(review): the traversal does not descend into the alternatives of a
-- @case@ whose scrutinee is a variable, nor into the scrutinee of any other
-- @case@. This is kept as-is.
tidyTuples :: CoreBind -> CoreBind
tidyTuples ce = case ce of
  NonRec x e -> NonRec x (run e)
  Rec xes    -> Rec (map (fmap run) xes)
  where
    -- every right-hand side is processed with a fresh, empty table
    run e = evalState (go e) []

    go (Tick t e)
      = Tick t <$> go e
    go (Let (NonRec x ex) e)
      = Let <$> (NonRec x <$> go ex) <*> go e
    go (Let (Rec bes) e)
      = Let <$> (Rec <$> mapM goRec bes) <*> go e
    go (Case (Var v) x t alts)
      = Case (Var v) x t <$> mapM (goAltR v) alts
    go (Case e x t alts)
      = Case e x t <$> mapM goAlt alts
    go (App e1 e2)
      = App <$> go e1 <*> go e2
    go (Lam x e)
      = Lam x <$> go e
    go (Cast e c)
      = (`Cast` c) <$> go e
    go e
      = return e

    goRec (x, e)
      = (x,) <$> go e

    goAlt (Alt c bs e)
      = Alt c bs <$> go e

    -- alternatives of a case on the variable @v@: reuse the binders recorded
    -- for (c, v) if there are any, otherwise record this alternative's ones
    goAltR v (Alt c bs e)
      = do table <- get
           case L.lookup (c, v) table of
             Just bs' ->
               return (Alt c bs' (substTuple bs' bs e))
             Nothing  -> do
               let bs' = mkAlive <$> bs
               modify (((c, v), bs') :)
               return (Alt c bs' e)



-- | Replace a let-bound field projection by the projected variable itself.
--
-- Consider a binding @let x = ex in body@ where @ex@, after stripping ticks,
-- is a @case@ with a single alternative whose right-hand side is one of its
-- own pattern variables @z@. It is rewritten to @let z = ex' in body'[x := z]@,
-- where @ex'@ and @body'@ are the recursively normalised sub-expressions.
-- Every other expression is traversed structurally.
normalizeTuples :: CoreBind -> CoreBind
normalizeTuples cb = case cb of
  NonRec x e -> NonRec x (go e)
  Rec xes    -> Rec [ (x, go e) | (x, e) <- xes ]
  where
    go (Let (NonRec x ex) e)
      | Case _ _ _ [Alt _ vs (Var z)] <- unTickExpr ex
      , z `elem` vs
      = Let (NonRec z (go ex)) (substTuple [z] [x] (go e))
      | otherwise
      = Let (NonRec x (go ex)) (go e)
    go (Let (Rec xes) e)
      = Let (Rec (fmap go <$> xes)) (go e)
    go (App e1 e2)
      = App (go e1) (go e2)
    go (Lam x e)
      = Lam x (go e)
    go (Case e b t alts)
      = Case (go e) b t [ Alt c bs (go rhs) | Alt c bs rhs <- alts ]
    go (Cast e c)
      = Cast (go e) c
    go (Tick t e)
      = Tick t (go e)
    go e@(Type _)
      = e
    go e@(Coercion _)
      = e
    go e@(Lit _)
      = e
    go e@(Var _)
      = e


--------------------------------------------------------------------------------
-- | A @RewriteRule@ is a partial rewrite of a single 'CoreExpr'. It yields
-- @Just e'@ when the rule fires, meaning the expression is replaced by @e'@.
-- It yields @Nothing@ when the rule does not apply, in which case
-- 'rewriteWith' keeps the expression unchanged.
type RewriteRule = CoreExpr -> Maybe CoreExpr
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
-- | Apply a 'RewriteRule' throughout the right-hand side(s) of a binding,
-- using 'rewriteWith'. The binders themselves are left untouched.
rewriteBindWith :: RewriteRule -> CoreBind -> CoreBind
--------------------------------------------------------------------------------
rewriteBindWith r (NonRec x e) = NonRec x (rewriteWith r e)
rewriteBindWith r (Rec xes)    = Rec [ (x, rewriteWith r e) | (x, e) <- xes ]

--------------------------------------------------------------------------------
-- | Apply a 'RewriteRule' to every sub-expression, top-down.
--
-- At each node the rule is tried once. If it fires, its result replaces the
-- node; otherwise the node is kept. The traversal then descends into the
-- children of that (possibly rewritten) node. The rule is not tried again on
-- the same node, either on its own output or after the children have been
-- rewritten.
rewriteWith :: RewriteRule -> CoreExpr -> CoreExpr
--------------------------------------------------------------------------------
rewriteWith tx = go
  where
    -- rewrite the node itself first, then recurse into the result's children
    go = step . txTop

    txTop e = fromMaybe e (tx e)

    goB (Rec xes)    = Rec (fmap go <$> xes)
    goB (NonRec x e) = NonRec x (go e)

    goAlt (Alt c bs e) = Alt c bs (go e)

    step (Let b e)       = Let (goB b) (go e)
    step (App e e')      = App (go e) (go e')
    step (Lam x e)       = Lam x (go e)
    step (Cast e c)      = Cast (go e) c
    step (Tick t e)      = Tick t (go e)
    step (Case e x t cs) = Case (go e) x t (map goAlt cs)
    step e@(Type _)      = e
    step e@(Lit _)       = e
    step e@(Var _)       = e
    step e@(Coercion _)  = e


--------------------------------------------------------------------------------
-- | Rewriting Pattern-Match-Tuples --------------------------------------------
--------------------------------------------------------------------------------

-- | Turns lazy, projection-based tuple pattern bindings into a strict @case@.
--
-- First form. It transforms
--
-- > let ds = case e0 of
-- >            pat -> (x1,...,xn)
-- >     y1 = proj1 ds
-- >     ...
-- >     yn = projn ds
-- >  in e1
--
--  to
--
-- > let ds = case e0 of
-- >            pat -> (x1,...,xn)
-- >  in case e0 of
-- >       pat -> e1[y1 := x1,..., yn := xn]
--
-- Here @pat@ may itself be a chain of nested single-alternative cases (see
-- 'onlyHasATupleInNestedCases'). The binding of @ds@ is kept because @e1@ may
-- still mention it.
--
-- Second form. It transforms
--
-- > let y1 = case v of
-- >            C x1 ... xn -> x1
-- >     y2 = proj2 v
-- >     ...
-- >     yn = projn v
-- >  in e1
--
--  to
--
-- > case v of
-- >   C x1 ... xn -> e1[y1 := x1,..., yn := xn]
--
-- The substitution is realised by renaming the pattern variables to the
-- corresponding @yi@. Let @n@ be the number of fields. Only the @n@ (first
-- form) or @n - 1@ (second form) lets that immediately follow are inspected.
-- Those lets that are not projections are kept, in their original order, as
-- lets inside the new alternative. The rule only fires when @n > 1@.
--
-- The resulting @case@ is strict, so the transformation changes the meaning
-- of the expression if evaluation order matters. But it changes it in a way
-- that LH cannot distinguish.
--
-- The purpose of the transformations is to unpack all of the variables in
-- @pat@ at once in a single scope when verifying @e1@. This allows LH to see
-- the dependencies between the fields of @pat@.
strictifyLazyLets :: RewriteRule
strictifyLazyLets (Let (NonRec x e@(Case _ _ _ [Alt (DataAlt _) _ _])) rest)
  | Just (bs, bs') <- onlyHasATupleInNestedCases e
  , null (bs' L.\\ bs) -- all tuple components are pattern variables, each used once
  , let n = length bs'
  , n > 1
  = let (nrbinds, e') = takeBinds n rest
        (ss, e'')     = collectTupleProjections x bs' nrbinds e'
     in Just $ Let (NonRec x e) $
          replaceAltInNestedCases (Ghc.exprType e') ss e'' e

strictifyLazyLets (Let (NonRec x e@(Case e0 _ _ [Alt (DataAlt _) bs _])) rest)
  | Just v0 <- isVar e0
  , isJust (isProjectionOf v0 e)
  , let n = length bs
  , n > 1
  = let (nrbinds, e') = takeBinds (n - 1) rest
        -- the binding of x is itself the first projection of v0
        (ss, e'')     = collectTupleProjections v0 bs ((x, e) : nrbinds) e'
     in Just $ replaceAltInNestedCases (Ghc.exprType e') ss e'' e

strictifyLazyLets _
  = Nothing

-- | Shared worker of the two 'strictifyLazyLets' clauses.
--
-- @collectTupleProjections scrut fieldVars binds body@ walks @binds@ in order
-- and splits them into two groups:
--
--   * projections of @scrut@ (see 'isProjectionOf') of the field at index
--     @i@. Each becomes a renaming pair @(fieldVars !! i, boundVar)@.
--   * all other bindings. These are re-wrapped, in their original order, as
--     non-recursive lets around @body@.
--
-- A projection becomes a renaming pair only if its index is in range for
-- @fieldVars@ and no earlier binding already claimed that field. Otherwise it
-- is kept as an ordinary let, so its binder stays in scope and no
-- out-of-range list index can occur.
collectTupleProjections
  :: Var               -- ^ variable whose fields are projected
  -> [Var]             -- ^ pattern variables, one per field
  -> [(Var, CoreExpr)] -- ^ candidate bindings, outermost first
  -> CoreExpr          -- ^ body that the non-projection bindings scope over
  -> ([(Var, Var)], CoreExpr)
collectTupleProjections scrut fieldVars binds body =
  (reverse renames, foldr wrap body (reverse others))
  where
    (renames, others) = L.foldl' classify ([], []) binds

    classify (ss, os) b@(v, ce)
      | Just i <- isProjectionOf scrut ce
      , Just f <- lookup i (zip [0 ..] fieldVars)
      , not (f `elem` map fst ss)
      = ((f, v) : ss, os)
      | otherwise
      = (ss, b : os)

    wrap (v, ce) = Let (NonRec v ce)

-- | Replaces the expression at the end of a chain of nested cases that each
-- have a single alternative.
--
-- @replaceAltInNestedCases t ss ef e@ walks down @e@ while it is a @case@
-- with exactly one alternative. Each such @case@ is rebuilt with result type
-- @t@, and every pattern variable @b@ is renamed to @v@ when @(b, v)@ is in
-- @ss@. The first sub-expression that is not such a @case@ (including a
-- ticked one, since ticks are not looked through) is replaced by @ef@.
replaceAltInNestedCases
  :: Type         -- ^ Result type given to every rebuilt case
  -> [(Var, Var)] -- ^ Renaming of pattern variables
  -> CoreExpr     -- ^ The expression to place at the end of the nested cases
  -> CoreExpr     -- ^ The expression with the nested cases
  -> CoreExpr
replaceAltInNestedCases t ss ef = go
  where
    go (Case e0 v _ [Alt c bs e1]) =
      Case e0 v t [Alt c (map rename bs) (go e1)]
    go _ = ef

    rename b = fromMaybe b (lookup b ss)


-- | @takeBinds n e@ strips up to @n@ leading non-recursive lets from @e@.
--
-- It returns the stripped bindings, outermost first, together with the
-- remaining expression. Fewer than @n@ bindings are returned when @e@ runs
-- out of leading non-recursive lets. A non-positive @n@ takes nothing.
takeBinds  :: Nat -> CoreExpr -> ([(Var, CoreExpr)], CoreExpr)
takeBinds nat ce = go nat ce
  where
    -- stop on any non-positive count, not just 0, so that a negative count
    -- cannot consume every leading let
    go n e
      | n <= 0 = ([], e)
    go n (Let (NonRec x e) e') =
      let (xes, e'') = go (n - 1) e'
       in ((x, e) : xes, e'')
    go _ e = ([], e)

-- | Checks whether an expression projects a field out of the variable @v@.
--
-- The expression must have the form @case v of K ys -> y@: a single
-- data-constructor alternative whose right-hand side is one of its own
-- pattern variables @y@. The scrutinee may be wrapped in ticks (see 'isVar').
-- The result is the index of @y@ in @ys@.
--
-- NOTE(review): the data constructor @K@ is not checked against anything.
isProjectionOf :: Var -> CoreExpr -> Maybe Int
isProjectionOf v (Case xe _ _ [Alt (DataAlt _) ys (Var y)])
  | isVar xe == Just v = L.elemIndex y ys
isProjectionOf _ _ = Nothing

--------------------------------------------------------------------------------
-- | @substTuple xs ys e@ returns @e[y1 := x1, ..., yn := xn]@. Each variable
-- of @ys@ is replaced by the variable at the same position in @xs@. Extra
-- elements of the longer list are ignored. If a variable occurs more than
-- once in @ys@, its last pairing wins.
--------------------------------------------------------------------------------
substTuple :: [Var] -> [Var] -> CoreExpr -> CoreExpr
substTuple xs ys = substExpr (M.fromList (zip ys xs))

-- | Yields the tuple of variables at the end of nested cases with
-- a single alternative each.
--
-- > case e0 of
-- >   pat0 -> case e1 of
-- >     pat1 -> (x1,...,xn)
--
-- It yields both the bound variables of the patterns and the variables
-- @x1,...,xn@. The pattern variables are listed innermost pattern first.
-- Only data-constructor alternatives are followed. The result is @Nothing@
-- when the innermost expression is not a tuple of variables (see 'isTuple').
onlyHasATupleInNestedCases :: CoreExpr -> Maybe ([Var], [Var])
onlyHasATupleInNestedCases = go []
  where
    go acc (Case _ _ _ [Alt (DataAlt _) bs e]) = go (bs : acc) e
    go acc e = fmap (\xs -> (concat acc, xs)) (isTuple e)

-- | Recognises an application of a tuple data constructor (a head variable
-- for which 'isTupleId' holds) whose value arguments are all variables,
-- possibly under ticks, and yields those variables.
--
-- The value arguments are taken to be the 'secondHalf' of all arguments.
-- This assumes that the constructor is applied to exactly as many type
-- arguments as value arguments, as for a boxed pair applied to two types
-- and then two values. Any expression that does not match yields 'Nothing'.
isTuple :: CoreExpr -> Maybe [Var]
isTuple e
  | (Var t, es) <- collectArgs e
  , isTupleId t
  = mapM isVar (secondHalf es)
  | otherwise
  = Nothing

-- | Yields the variable an expression consists of, looking through any
-- number of enclosing 'Tick's. Every other expression form yields 'Nothing'.
isVar :: CoreExpr -> Maybe Var
isVar (Var x)    = Just x
isVar (Tick _ e) = isVar e
isVar _          = Nothing

-- | Drops the first @length xs `div` 2@ elements of a (finite) list.
--
-- For lists of odd length the middle element is kept, e.g.
-- @secondHalf [1,2,3] == [2,3]@; @secondHalf []@ is @[]@.
secondHalf :: [a] -> [a]
secondHalf xs = drop (length xs `div` 2) xs


-- | Rewrite rule that applies 'inlineLoopBreaker' to the binding of every
-- 'Let', leaving the 'Let' body untouched. It yields 'Just' for every 'Let',
-- even when 'inlineLoopBreaker' returns the binding unchanged, and 'Nothing'
-- for every other expression.
inlineLoopBreakerTx :: RewriteRule
inlineLoopBreakerTx (Let b e) = Just (Let (inlineLoopBreaker b) e)
inlineLoopBreakerTx _         = Nothing

-- | Changes non-recursive bindings of the form
--
-- > v = /\a1...ak -> \x1...xn ->
-- >   let b1 in ... let bp in        -- zero or more non-recursive lets
-- >   letrec v0 = e0
-- >   in v0 xj ... xn
--
-- to the recursive binding
--
-- > v = /\a1...ak -> \x1...x(j-1) ->
-- >   let b1 in ... let bp in
-- >   e0 [ v0 := v a1...ak x1...x(j-1) ]
--
-- The rewrite fires only when all of the following hold:
--
--   * @v0@ is the only binder of the @letrec@;
--   * @v0@ is a strong loop breaker without a NOINLINE pragma;
--   * @v0@ is applied only to variables (ticks are looked through);
--   * those variables are exactly the trailing value binders @xj ... xn@
--     of @v@.
--
-- Any other binding is returned unchanged.
inlineLoopBreaker :: Bind Id -> Bind Id
inlineLoopBreaker (NonRec x e)
    | Just (lbx, lbe, lbargs) <- hasLoopBreaker be =
       let asPrefix = take (length as - length lbargs) as
           lbe'     = sub (M.singleton lbx (ecall asPrefix)) lbe
        in Rec [(x, mkLams (αs ++ asPrefix) (mkLets nrbinds lbe'))]
  where
    (αs, as, e')  = collectTyAndValBinders e
    (nrbinds, be) = collectNonRecLets e'

    -- @ecall ys@ builds the application @x αs ys@, which replaces the loop
    -- breaker in its own right-hand side.
    ecall :: [Var] -> CoreExpr
    ecall ys =
      L.foldl' App (L.foldl' App (Var x) (Type . TyVarTy <$> αs)) (Var <$> ys)

    hasLoopBreaker :: CoreExpr -> Maybe (Var, CoreExpr, [CoreExpr])
    hasLoopBreaker (Let (Rec [(x1, e1)]) e2)
      | not (isNoInlinePragma (idInlinePragma x1))
      , (Var x2, args) <- collectArgs e2
      , isLoopBreaker x1
      , x1 == x2
      , all (isJust . isVar) args
        -- Extract the argument variables with 'isVar' as well, so that
        -- ticked arguments are not silently dropped from the suffix check
        -- (which would let the rewrite fire on mismatched arguments).
      , L.isSuffixOf (mapMaybe isVar args) as
      = Just (x1, e1, args)
    hasLoopBreaker _ = Nothing

    isLoopBreaker :: Var -> Bool
    isLoopBreaker = isStrongLoopBreaker . occInfo . idInfo

inlineLoopBreaker bs
  = bs

-- | Peels off the leading non-recursive 'Let's of an expression:
--
-- > collectNonRecLets (let b1 in ... let bn in e) == ([b1, ..., bn], e)
--
-- where @e@ does not itself start with a non-recursive 'Let'. The bindings
-- are returned outermost first. A recursive 'Let' (or any other expression
-- form, including a 'Tick') stops the traversal.
collectNonRecLets :: Expr t -> ([Bind t], Expr t)
collectNonRecLets = go []
  where
    go :: [Bind b] -> Expr b -> ([Bind b], Expr b)
    go bs (Let b@(NonRec _ _) body) = go (b : bs) body
    go bs body                      = (reverse bs, body)

-- | Inlines bindings of the form
--
-- > let v = \x -> e0
-- >  in e1
--
-- where @v@ is a local identifier whose name is system-generated and starts
-- with @"fail"@. These are presumably the pattern-match failure
-- continuations introduced by the desugarer (NOTE(review): confirm).
--
-- Every application @v a@ inside @e1@ is replaced by @e0@. This beta
-- reduction is valid because @x@ must not be free in @e0@. The binding of
-- @v@ is dropped, unless @v@ still occurs in the rewritten @e1@ (for
-- instance because it is also used unapplied); in that case it is kept so
-- that no variable is left unbound.
--
-- Inlining an application of such a @v@ whose right-hand side is not a
-- lambda with an unused binder is an internal error ('impossible').
inlineFailCases :: CoreBind -> CoreBind
inlineFailCases = go []
  where
    go :: [(CoreBndr, CoreExpr)] -> CoreBind -> CoreBind
    go su (Rec xes)    = Rec (fmap (go' su) <$> xes)
    go su (NonRec x e) = NonRec x (go' su e)

    go' :: [(CoreBndr, CoreExpr)] -> CoreExpr -> CoreExpr
    go' su (App (Var x) _)
      | isFailId x, Just e <- getFailExpr x su = e
    go' su (Let (NonRec x ex) e)
      | isFailId x
      = let ex' = go' su ex
            e'  = go' (addFailExpr x ex' su) e
        in  -- Dropping the binding is only safe once every occurrence of
            -- the fail id has been inlined away.
            if x `elemVarSet` exprFreeVars e'
              then Let (NonRec x ex') e'
              else e'

    go' su (App e1 e2)      = App (go' su e1) (go' su e2)
    go' su (Lam x e)        = Lam x (go' su e)
    go' su (Let xs e)       = Let (go su xs) (go' su e)
    go' su (Case e x t alt) = Case (go' su e) x t (goalt su <$> alt)
    go' su (Cast e c)       = Cast (go' su e) c
    go' su (Tick t e)       = Tick t (go' su e)
    go' _  e                = e

    goalt :: [(CoreBndr, CoreExpr)] -> Alt CoreBndr -> Alt CoreBndr
    goalt su (Alt c xs e) = Alt c xs (go' su e)

    isFailId :: CoreBndr -> Bool
    isFailId x =
      isLocalId x
        && isSystemName (varName x)
        && L.isPrefixOf "fail" (getOccString x)

    getFailExpr :: Eq a => a -> [(a, b)] -> Maybe b
    getFailExpr = L.lookup

    -- Records the body of a fail continuation @\v -> e@, provided the
    -- binder @v@ is unused, so applications of the continuation can be
    -- replaced by that body.
    addFailExpr :: a -> CoreExpr -> [(a, CoreExpr)] -> [(a, CoreExpr)]
    addFailExpr x (Lam v e) su
      | not (elemVarSet v $ exprFreeVars e) = (x, e) : su
    addFailExpr _ _ _ = impossible Nothing "internal error" -- this cannot happen