module GHC.Core.Opt.FloatOut ( floatOutwards ) where
import GHC.Prelude
import GHC.Core
import GHC.Core.Utils
import GHC.Core.Make
import GHC.Core.Opt.Monad ( FloatOutSwitches(..) )
import GHC.Driver.Flags  ( DumpFlag (..) )
import GHC.Utils.Logger
import GHC.Types.Id      ( Id, idType,
                           isJoinId, isJoinId_maybe )
import GHC.Types.Tickish
import GHC.Core.Opt.SetLevels
import GHC.Types.Unique.Supply ( UniqSupply )
import GHC.Data.Bag
import GHC.Utils.Misc
import GHC.Data.Maybe
import GHC.Utils.Outputable
import GHC.Utils.Panic
import GHC.Core.Type
import qualified Data.IntMap as M
import Data.List        ( partition )
-- | Run the float-out pass over a whole Core program.
--
-- The program is first annotated with levels by 'setLevels'; each annotated
-- top-level binding is then processed by 'floatTopBind'.  Two optional dumps
-- are emitted through the logger: the level-annotated program, and the
-- accumulated 'FloatStats'.  The result is the concatenation of the bags of
-- bindings produced for each top-level binding, in program order.
floatOutwards :: Logger
              -> FloatOutSwitches
              -> UniqSupply
              -> CoreProgram -> IO CoreProgram
floatOutwards logger float_sws us pgm = do
  let annotated_w_levels = setLevels float_sws pgm us
      (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)

  putDumpFileMaybe logger Opt_D_verbose_core2core "Levels added:"
    FormatCore
    (vcat (map ppr annotated_w_levels))

  -- The three counters, in the order stored in 'FlS'.
  let (tlets, ntlets, lams) = get_stats (sum_stats fss)

  putDumpFileMaybe logger Opt_D_dump_simpl_stats "FloatOut stats:"
    FormatText
    (hcat [ int tlets,  text " Lets floated to top level; "
          , int ntlets, text " Lets floated elsewhere; from "
          , int lams,   text " Lambda groups" ])

  return (bagToList (unionManyBags binds_s'))
-- | Float a single top-level binding.
--
-- 'floatBind' yields the rewritten binding(s) together with the floats that
-- escaped from the right-hand side(s); those floats are turned into a bag of
-- ordinary bindings by 'flattenTopFloats'.
--
--   * A single @Rec@ group absorbs the floats via 'addTopFloatPairs' and is
--     returned as one binding.
--   * A single @NonRec@ binding is appended after the floats.
--   * Any other result shape from 'floatBind' is treated as a compiler bug
--     and panics.
floatTopBind :: LevelledBind -> (FloatStats, Bag CoreBind)
floatTopBind bind =
  case floatBind bind of
    (fs, floats, bind') ->
      let float_bag = flattenTopFloats floats
      in case bind' of
           [Rec prs]    -> (fs, unitBag (Rec (addTopFloatPairs float_bag prs)))
           [NonRec b e] -> (fs, float_bag `snocBag` NonRec b e)
           _            -> pprPanic "floatTopBind" (ppr bind')
-- | Float out of the right-hand side(s) of a binding.
--
-- Returns the statistics, the floats that escape the binding, and the
-- rewritten binding(s) with their level annotations dropped.
--
-- For a @NonRec@ binding the result list always has exactly one element.
--
-- For a @Rec@ group each pair is processed by the local @do_pair@, which
-- returns two lists of pairs.  The first lists become individual @NonRec@
-- bindings, emitted first.  The second lists are concatenated and split into
-- join-point binders and the rest; each non-empty half becomes its own @Rec@
-- group (non-join group first).  If both halves are empty a single empty
-- @Rec@ is produced.
floatBind :: LevelledBind -> (FloatStats, FloatBinds, [CoreBind])
floatBind (NonRec (TB var _) rhs) =
  case floatRhs var rhs of
    (fs, rhs_floats, rhs') -> (fs, rhs_floats, [NonRec var rhs'])

floatBind (Rec pairs) =
  case floatList do_pair pairs of
    (fs, rhs_floats, new_pairs) ->
      let (new_ul_pairss, new_other_pairss) = unzip new_pairs
          (new_join_pairs, new_l_pairs) =
            partition (isJoinId . fst) (concat new_other_pairss)
          -- Join binders and ordinary binders are kept in separate groups.
          new_rec_binds
            | null new_join_pairs = [ Rec new_l_pairs ]
            | null new_l_pairs    = [ Rec new_join_pairs ]
            | otherwise           = [ Rec new_l_pairs, Rec new_join_pairs ]
          new_non_rec_binds = [ NonRec b e | (b, e) <- concat new_ul_pairss ]
      in (fs, rhs_floats, new_non_rec_binds ++ new_rec_binds)
  where
    -- Process one (binder, rhs) pair of the group.  The two result lists
    -- are the pairs picked out by 'splitRecFloats' as unlifted non-join
    -- non-recursive bindings, and all other pairs (including this binder's
    -- own rewritten pair), respectively.
    do_pair :: (LevelledBndr, LevelledExpr)
            -> (FloatStats, FloatBinds,
                ([(Id,CoreExpr)],
                 [(Id,CoreExpr)]))
    do_pair (TB name spec, rhs)
      | isTopLvl dest_lvl
        -- Destination is top level: every float from the rhs is flattened
        -- to a top-level binding and merged into the group's pairs, so no
        -- floats are reported upwards.
      = case floatRhs name rhs of
          (fs, rhs_floats, rhs') ->
            ( fs
            , emptyFloats
            , ([], addTopFloatPairs (flattenTopFloats rhs_floats)
                                    [(name, rhs')]) )

      | otherwise
        -- Split the rhs floats at the destination level.  Of the floats
        -- that 'partitionByLevel' returns as @heres@, the leading
        -- let-bindings join the group (see 'splitRecFloats') and whatever
        -- is left over is re-installed inside the rhs, under its lambdas.
      = case floatRhs name rhs of
          (fs, rhs_floats, rhs') ->
            case partitionByLevel dest_lvl rhs_floats of
              (rhs_floats', heres) ->
                case splitRecFloats heres of
                  (ul_pairs, pairs, case_heres) ->
                    let pairs' = (name, installUnderLambdas case_heres rhs') : pairs
                    in (fs, rhs_floats', (ul_pairs, pairs'))
      where
        dest_lvl = floatSpecLevel spec
-- | Peel the leading let-floats off a bag of floats.
--
-- The bag is traversed from the front for as long as the elements are
-- 'FloatLet's:
--
--   * a @NonRec@ binding whose binder has an unlifted type and is not a join
--     point goes into the first result list (returned in encounter order);
--   * every other @NonRec@ binding, and all pairs of a @Rec@ group, go into
--     the second result list (accumulated by prepending).
--
-- Traversal stops at the first float that is not a 'FloatLet'; that float and
-- everything after it are returned, in order, as the third component.
splitRecFloats :: Bag FloatBind
               -> ([(Id,CoreExpr)],
                   [(Id,CoreExpr)],
                   Bag FloatBind)
splitRecFloats fs
  = go [] [] (bagToList fs)
  where
    go ul_prs prs (FloatLet (NonRec b r) : rest)
      | isUnliftedType (idType b)
      , not (isJoinId b)
      = go ((b, r) : ul_prs) prs rest
      | otherwise
      = go ul_prs ((b, r) : prs) rest
    go ul_prs prs (FloatLet (Rec prs') : rest)
      = go ul_prs (prs' ++ prs) rest
    go ul_prs prs rest
      = (reverse ul_prs, prs, listToBag rest)
                                                   
                                                   
-- | Wrap the given floats around the body of an expression, beneath any
-- leading lambdas.
--
-- With an empty bag the expression is returned untouched.  Otherwise the
-- leading @Lam@ binders are kept in place and 'install' is applied to the
-- first sub-expression that is not a @Lam@.
installUnderLambdas :: Bag FloatBind -> CoreExpr -> CoreExpr
installUnderLambdas floats e
  | isEmptyBag floats = e
  | otherwise         = go e
  where
    go (Lam b body) = Lam b (go body)
    go body         = install floats body
-- | Apply a floating function to every element of a list, combining the
-- per-element statistics with 'add_stats' and the per-element floats with
-- 'plusFloats' (both left to right), and collecting the results in order.
-- An empty list yields 'zeroStats', 'emptyFloats' and @[]@.
floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
floatList _ []       = (zeroStats, emptyFloats, [])
floatList f (a : as) =
  case f a of
    (fs_a, binds_a, b) ->
      case floatList f as of
        (fs_as, binds_as, bs) ->
          (fs_a `add_stats` fs_as, binds_a `plusFloats` binds_as, b : bs)
-- | Float out of an expression, then split the resulting floats at the given
-- level with 'partitionByLevel'.  The second component of that split
-- (@heres@) is wrapped back around the expression with 'install'; the first
-- component is returned as the floats that continue outwards.
--
-- Used in this module for lambda bodies, case-alternative right-hand sides,
-- join-point bodies and the body of a @let@ that stays put.
floatBody :: Level
          -> LevelledExpr
          -> (FloatStats, FloatBinds, CoreExpr)
floatBody lvl arg =
  case floatExpr arg of
    (fsa, floats, arg') ->
      case partitionByLevel lvl floats of
        (floats', heres) -> (fsa, floats', install heres arg')
-- | Float bindings out of a level-annotated expression.
--
-- Returns the statistics, the floats that escape the expression, and the
-- expression itself with the level annotations removed from its binders.
floatExpr :: LevelledExpr
          -> (FloatStats, FloatBinds, CoreExpr)

-- Leaves: nothing to float.
floatExpr (Var v)       = (zeroStats, emptyFloats, Var v)
floatExpr (Type ty)     = (zeroStats, emptyFloats, Type ty)
floatExpr (Coercion co) = (zeroStats, emptyFloats, Coercion co)
floatExpr (Lit lit)     = (zeroStats, emptyFloats, Lit lit)

-- Application: function first, then argument; both results go through
-- 'atJoinCeiling'.
floatExpr (App e a) =
  case atJoinCeiling (floatExpr e) of
    (fse, floats_e, e') ->
      case atJoinCeiling (floatExpr a) of
        (fsa, floats_a, a') ->
          (fse `add_stats` fsa, floats_e `plusFloats` floats_a, App e' a')

-- Lambda: all consecutive binders are handled as one group.  The body is
-- floated at the level of the first binder (adjusted by 'asJoinCeilLvl'),
-- and the escaping floats are recorded in the statistics.
floatExpr lam@(Lam (TB _ lam_spec) _) =
  let (bndrs_w_lvls, body) = collectBinders lam
      bndrs                = [b | TB b _ <- bndrs_w_lvls]
      bndr_lvl             = asJoinCeilLvl (floatSpecLevel lam_spec)
  in case floatBody bndr_lvl body of
       (fs, floats, body') ->
         (add_to_stats fs floats, floats, mkLams bndrs body')

floatExpr (Tick tickish expr)
  -- Tick scopes like 'SoftScope': floats pass through unchanged.
  | tickish `tickishScopesLike` SoftScope
  = case atJoinCeiling (floatExpr expr) of
      (fs, floating_defns, expr') ->
        (fs, floating_defns, Tick tickish expr')

  -- Tick does not count, or can be split: the floats are wrapped with the
  -- non-counting version of the tick before being passed on.
  | not (tickishCounts tickish) || tickishCanSplit tickish
  = case atJoinCeiling (floatExpr expr) of
      (fs, floating_defns, expr') ->
        let annotated_defns = wrapTick (mkNoCount tickish) floating_defns
        in (fs, annotated_defns, Tick tickish expr')

  -- Breakpoints: floats pass through unchanged, and 'atJoinCeiling' is not
  -- applied.
  | Breakpoint{} <- tickish
  = case floatExpr expr of
      (fs, floating_defns, expr') ->
        (fs, floating_defns, Tick tickish expr')

  | otherwise
  = pprPanic "floatExpr tick" (ppr tickish)

floatExpr (Cast expr co) =
  case atJoinCeiling (floatExpr expr) of
    (fs, floating_defns, expr') ->
      (fs, floating_defns, Cast expr' co)

floatExpr (Let bind body)
  = case bind_spec of
      -- The binding floats: it is removed from the expression and added to
      -- the floats at its destination level, between the floats from its
      -- own right-hand side(s) and those from the body.
      FloatMe dest_lvl ->
        case floatBind bind of
          (fsb, bind_floats, binds') ->
            case floatExpr body of
              (fse, body_floats, body') ->
                let new_bind_floats = foldr plusFloats emptyFloats
                                        (map (unitLetFloat dest_lvl) binds')
                in ( add_stats fsb fse
                   , bind_floats `plusFloats` new_bind_floats
                                 `plusFloats` body_floats
                   , body' )

      -- The binding stays: the body is floated with 'floatBody' at the
      -- binding's level and the rewritten binding(s) are wrapped around it.
      StayPut bind_lvl ->
        case floatBind bind of
          (fsb, bind_floats, binds') ->
            case floatBody bind_lvl body of
              (fse, body_floats, body') ->
                ( add_stats fsb fse
                , bind_floats `plusFloats` body_floats
                , foldr Let body' binds' )
  where
    -- The spec of the (first) binder decides the fate of the whole binding.
    bind_spec = case bind of
                  NonRec (TB _ s) _     -> s
                  Rec ((TB _ s, _) : _) -> s
                  Rec []                -> panic "floatExpr:rec"

floatExpr (Case scrut (TB case_bndr case_spec) ty alts)
  = case case_spec of
      -- The case itself floats.  Only a single data-constructor alternative
      -- is supported; the case becomes a float and the alternative's
      -- right-hand side takes the place of the whole expression.
      FloatMe dest_lvl
        | [Alt con@(DataAlt {}) bndrs rhs] <- alts ->
            case atJoinCeiling (floatExpr scrut) of
              (fse, fde, scrut') ->
                case floatExpr rhs of
                  (fsb, fdb, rhs') ->
                    let float = unitCaseFloat dest_lvl scrut' case_bndr con
                                  [b | TB b _ <- bndrs]
                    in ( add_stats fse fsb
                       , fde `plusFloats` float `plusFloats` fdb
                       , rhs' )
        | otherwise ->
            pprPanic "Floating multi-case" (ppr alts)

      -- The case stays: each alternative's right-hand side is floated with
      -- 'floatBody' at the case binder's level.
      StayPut bind_lvl ->
        case atJoinCeiling (floatExpr scrut) of
          (fse, fde, scrut') ->
            case floatList (float_alt bind_lvl) alts of
              (fsa, fda, alts') ->
                ( add_stats fse fsa
                , fda `plusFloats` fde
                , Case scrut' case_bndr ty alts' )
  where
    float_alt bind_lvl (Alt con bs rhs) =
      case floatBody bind_lvl rhs of
        (fs, rhs_floats, rhs') ->
          (fs, rhs_floats, Alt con [b | TB b _ <- bs] rhs')
-- | Float out of the right-hand side of a binding.
--
-- If the binder is a join point of arity @n@ and the right-hand side starts
-- with at least @n@ lambdas, exactly those @n@ binders are peeled off, the
-- remaining body is floated with 'floatBody' at the level of the first
-- binder, and the lambdas are put back with 'mkLams'.  A join point of
-- arity zero is simply passed to 'floatExpr'.
--
-- In every other case (not a join point, or too few lambdas) the
-- right-hand side is floated with 'floatExpr' and the result is passed
-- through 'atJoinCeiling'.
floatRhs :: CoreBndr
         -> LevelledExpr
         -> (FloatStats, FloatBinds, CoreExpr)
floatRhs bndr rhs
  | Just join_arity <- isJoinId_maybe bndr
  , Just (bndrs, body) <- try_collect join_arity rhs []
  = case bndrs of
      [] -> floatExpr rhs
      (TB _ lam_spec) : _ ->
        let lvl = floatSpecLevel lam_spec
        in case floatBody lvl body of
             (fs, floats, body') ->
               (fs, floats, mkLams [b | TB b _ <- bndrs] body')
  | otherwise
  = atJoinCeiling (floatExpr rhs)
  where
    -- Strip exactly @n@ leading lambda binders; 'Nothing' if the expression
    -- has fewer than @n@ of them.  The accumulator holds the binders seen so
    -- far, most recent first.
    try_collect :: Int -> Expr b -> [b] -> Maybe ([b], Expr b)
    try_collect 0 expr      acc = Just (reverse acc, expr)
    try_collect n (Lam b e) acc = try_collect (n - 1) e (b : acc)
    try_collect _ _         _   = Nothing
-- | Counters accumulated by the float-out pass.  The per-field comments give
-- the labels under which 'floatOutwards' reports each counter; see
-- 'add_to_stats' for how they are incremented.
data FloatStats
  = FlS Int  -- Reported as "Lets floated to top level"
        Int  -- Reported as "Lets floated elsewhere"
        Int  -- Reported as "Lambda groups"
-- | Unpack the three counters of a 'FloatStats', in constructor order.
get_stats :: FloatStats -> (Int, Int, Int)
get_stats (FlS a b c) = (a, b, c)
-- | Statistics with every counter at zero; the identity for 'add_stats'.
zeroStats :: FloatStats
zeroStats = FlS 0 0 0
-- | Add up a list of statistics counter-wise; 'zeroStats' for an empty list.
sum_stats :: [FloatStats] -> FloatStats
sum_stats xs = foldr add_stats zeroStats xs
-- | Counter-wise sum of two sets of statistics.
add_stats :: FloatStats -> FloatStats -> FloatStats
add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
  = FlS (a1 + a2) (b1 + b2) (c1 + c2)
-- | Record the floats that escaped one lambda group.
--
-- The first counter grows by the number of bindings in the first field of
-- the 'FB'; the second by the number of floats in its second field plus all
-- floats stored in its 'MajorEnv'; the third by one.
add_to_stats :: FloatStats -> FloatBinds -> FloatStats
add_to_stats (FlS a b c) (FB tops ceils others)
  = FlS (a + lengthBag tops)
        (b + lengthBag ceils + lengthBag (flattenMajor others))
        (c + 1)
type FloatLet = CoreBind        -- A floated let-binding is just a Core binding
type MajorEnv = M.IntMap MinorEnv         -- Keyed by the major component of a Level
type MinorEnv = M.IntMap (Bag FloatBind)  -- Keyed by the minor component of a Level
-- | The floats gathered so far, split by where they are headed.
data FloatBinds  = FB !(Bag FloatLet)           -- Let-bindings headed for the top level
                      !(Bag FloatBind)          -- Floats whose level type is JoinCeilLvl
                      !MajorEnv                 -- All remaining floats, indexed by major then minor level
     
-- | Debug rendering: the three components of an 'FB', one per line,
-- inside braces.
instance Outputable FloatBinds where
  ppr (FB fbs ceils defs)
      = text "FB" <+> (braces $ vcat
           [ text "tops ="     <+> ppr fbs
           , text "ceils ="    <+> ppr ceils
           , text "non-tops =" <+> ppr defs ])
-- | Extract the top-level bindings from a 'FloatBinds'.
--
-- Two 'assertPpr' checks guard the result: the per-level environment
-- must flatten to an empty bag, and the join-ceiling bag must be empty.
-- On failure the offending component is what gets pretty-printed.
flattenTopFloats :: FloatBinds -> Bag CoreBind
flattenTopFloats (FB tops ceils defs)
  = assertPpr (isEmptyBag (flattenMajor defs)) (ppr defs) $
    assertPpr (isEmptyBag ceils) (ppr ceils)
    tops
-- | Prepend the binder\/rhs pairs of every binding in the bag to an
-- existing list of pairs.  A 'NonRec' contributes a single pair; a 'Rec'
-- contributes all of its pairs.  The fold is a right fold, so pairs
-- appear in bag order, ahead of the original list.
addTopFloatPairs :: Bag CoreBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
addTopFloatPairs float_bag prs
  = foldr add prs float_bag
  where
    add (NonRec b r) rest = (b, r) : rest
    add (Rec prs1)   rest = prs1 ++ rest
-- | Collapse a whole major-level environment into a single bag by
-- flattening each minor environment and taking the union of the results.
flattenMajor :: MajorEnv -> Bag FloatBind
flattenMajor = M.foldr (unionBags . flattenMinor) emptyBag
-- | Collapse a minor-level environment into a single bag: the union of
-- the bags stored at every minor level.
flattenMinor :: MinorEnv -> Bag FloatBind
flattenMinor = M.foldr unionBags emptyBag
-- | A 'FloatBinds' holding nothing: no top-level floats, no
-- join-ceiling floats, and an empty level environment.
emptyFloats :: FloatBinds
emptyFloats = FB emptyBag emptyBag M.empty
-- | A 'FloatBinds' containing exactly one floated case, built as
-- @FloatCase e b con bs@.
--
-- If the level's type is 'JoinCeilLvl' the float goes into the
-- join-ceiling bag; otherwise it is filed in the level environment
-- under the level's major and minor numbers.  It never goes into the
-- top-level bag.
unitCaseFloat :: Level -> CoreExpr -> Id -> AltCon -> [Var] -> FloatBinds
unitCaseFloat (Level major minor t) e b con bs
  | t == JoinCeilLvl
  = FB emptyBag floats M.empty
  | otherwise
  = FB emptyBag emptyBag (M.singleton major (M.singleton minor floats))
  where
    floats = unitBag (FloatCase e b con bs)
-- | A 'FloatBinds' containing exactly one floated let-binding.
--
-- The destination is chosen by the first matching case:
--
--   1. a top level (per 'isTopLvl'): the raw binding goes into the
--      top-level bag;
--   2. level type 'JoinCeilLvl': wrapped as a 'FloatLet' in the
--      join-ceiling bag;
--   3. otherwise: wrapped as a 'FloatLet' and filed in the level
--      environment under the level's major and minor numbers.
unitLetFloat :: Level -> FloatLet -> FloatBinds
unitLetFloat lvl@(Level major minor t) b
  | isTopLvl lvl     = FB (unitBag b) emptyBag M.empty
  | t == JoinCeilLvl = FB emptyBag floats M.empty
  | otherwise        = FB emptyBag emptyBag
                          (M.singleton major (M.singleton minor floats))
  where
    floats = unitBag (FloatLet b)
-- | Merge two 'FloatBinds' component by component: the top-level bags
-- and the join-ceiling bags are unioned, and the level environments are
-- merged with 'plusMajor'.
plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
plusFloats (FB t1 c1 l1) (FB t2 c2 l2)
  = FB (t1 `unionBags` t2) (c1 `unionBags` c2) (l1 `plusMajor` l2)
-- | Merge two major-level environments; where both have an entry for
-- the same major level, the minor environments are merged with
-- 'plusMinor'.
plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
plusMajor = M.unionWith plusMinor
-- | Merge two minor-level environments; where both have an entry for
-- the same minor level, the bags are unioned.
plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
plusMinor = M.unionWith unionBags
-- | Apply 'wrapFloat' to an expression once for each float in the bag.
-- It is a right fold, so the first float in the bag is applied last,
-- i.e. outermost.  An empty bag returns the expression unchanged.
install :: Bag FloatBind -> CoreExpr -> CoreExpr
install defn_groups expr
  = foldr wrapFloat expr defn_groups
-- | Split a 'FloatBinds' at a given level into the floats that carry on
-- outwards and the floats that stop here.
--
-- Carrying on (first component):
--
--   * all top-level floats;
--   * the join-ceiling floats, unless the partition level's type is
--     'JoinCeilLvl';
--   * environment entries whose major level is smaller than the
--     partition's, plus those with the same major level and a smaller
--     minor level.
--
-- Stopping here (second component), unioned in this order:
--
--   * the entry at exactly the partition's major and minor level;
--   * the join-ceiling floats, if the partition level's type is
--     'JoinCeilLvl';
--   * entries with the same major level and a larger minor level;
--   * entries with a larger major level.
partitionByLevel
        :: Level                -- Partitioning level
        -> FloatBinds           -- Floats to be divided into two piles
        -> (FloatBinds,         -- Floats that continue outwards
            Bag FloatBind)      -- Floats to be installed at this level
partitionByLevel (Level major minor typ) (FB tops ceils defns)
  = (FB tops ceils' (outer_maj `plusMajor` M.singleton major outer_min),
     here_min `unionBags` here_ceil
              `unionBags` flattenMinor inner_min
              `unionBags` flattenMajor inner_maj)
  where
    -- Split the environment around the partition's major level ...
    (outer_maj, mb_here_maj, inner_maj) = M.splitLookup major defns

    -- ... and, if that major level is present, split its minor
    -- environment around the partition's minor level.
    (outer_min, mb_here_min, inner_min)
      = case mb_here_maj of
          Nothing        -> (M.empty, Nothing, M.empty)
          Just min_defns -> M.splitLookup minor min_defns

    here_min = mb_here_min `orElse` emptyBag

    -- Join-ceiling floats are released only at a JoinCeilLvl partition.
    (here_ceil, ceils') | typ == JoinCeilLvl = (ceils, emptyBag)
                        | otherwise          = (emptyBag, ceils)
-- | Peel the join-ceiling floats off a 'FloatBinds'.  The result pairs
-- the same floats with an emptied join-ceiling bag together with the
-- bag that was removed.
partitionAtJoinCeiling :: FloatBinds -> (FloatBinds, Bag FloatBind)
partitionAtJoinCeiling (FB tops ceils defs)
  = (FB tops emptyBag defs, ceils)
-- | Install the join-ceiling floats around the expression (via
-- 'install') and return the remaining floats with the join-ceiling bag
-- emptied.  The statistics are passed through untouched.
atJoinCeiling :: (FloatStats, FloatBinds, CoreExpr)
              -> (FloatStats, FloatBinds, CoreExpr)
atJoinCeiling (fs, floats, expr')
  = (fs, floats', install ceils expr')
  where
    (floats', ceils) = partitionAtJoinCeiling floats
-- | Apply a tick to every expression carried by a 'FloatBinds': the
-- right-hand sides of all let-bindings (top-level, join-ceiling and
-- per-level) and the scrutinee expression of every 'FloatCase'.
-- Binders, alternative constructors and pattern variables are left
-- alone.
wrapTick :: CoreTickish -> FloatBinds -> FloatBinds
wrapTick t (FB tops ceils defns)
  = FB (mapBag wrap_bind tops) (wrap_defns ceils)
       (M.map (M.map wrap_defns) defns)
  where
    wrap_defns = mapBag wrap_one

    wrap_bind (NonRec binder rhs) = NonRec binder (maybe_tick rhs)
    wrap_bind (Rec pairs)         = Rec (mapSnd maybe_tick pairs)

    wrap_one (FloatLet bind)      = FloatLet (wrap_bind bind)
    wrap_one (FloatCase e b c bs) = FloatCase (maybe_tick e) b c bs

    -- Expressions that 'exprIsHNF' accepts are handed to 'tickHNFArgs'
    -- rather than being wrapped directly; everything else goes through
    -- 'mkTick'.
    maybe_tick e | exprIsHNF e = tickHNFArgs t e
                 | otherwise   = mkTick t e