{-# LANGUAGE BangPatterns, CPP, GADTs #-}
module CmmBuildInfoTables
    ( CAFSet, CAFEnv, cafAnal
    , doSRTs, TopSRT, emptySRT, isEmptySRT, srtToData )
where
#include "HsVersions.h"
import GhcPrelude hiding (succ)
import Hoopl.Block
import Hoopl.Graph
import Hoopl.Label
import Hoopl.Collections
import Hoopl.Dataflow
import Digraph
import Bitmap
import CLabel
import PprCmmDecl ()
import Cmm
import CmmUtils
import CmmInfo
import Data.List
import DynFlags
import Maybes
import Outputable
import SMRep
import UniqSupply
import Util
import PprCmm()
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Control.Monad
-- | Right fold over the elements of a set, visited in ascending order.
foldSet :: (a -> b -> b) -> b -> Set a -> b
foldSet step z = foldr step z . Set.toAscList
-- | A set of closure labels of CAFs (as collected by 'cafsInNode').
type CAFSet = Set CLabel
-- | CAF sets per block, keyed by the block's entry label (result of 'cafAnal').
type CAFEnv = LabelMap CAFSet
-- | Lattice for the CAF analysis. Facts are sets of CAF closure labels;
-- the bottom element is the empty set and facts are joined by union.
cafLattice :: DataflowLattice CAFSet
cafLattice = DataflowLattice Set.empty joinCafs
  where
    -- Report a change only when the union actually grew the old fact.
    joinCafs (OldFact oldCafs) (NewFact newCafs) =
        let !merged = Set.union oldCafs newCafs
            grew    = Set.size oldCafs < Set.size merged
        in changedIf grew merged
-- | Backward transfer function for the CAF analysis. It starts from the
-- join of the successors' facts, adds the CAFs mentioned by the exit node,
-- then the CAFs of each middle node (walking backwards). The result is
-- recorded as the fact for the block's entry label.
cafTransfers :: TransferFun CAFSet
cafTransfers (BlockCC eNode middle xNode) fBase =
    let successorFacts = joinOutFacts cafLattice xNode fBase
        atExit         = cafsInNode xNode $! successorFacts
        !atEntry       = foldNodesBwdOO cafsInNode middle atExit
    in mapSingleton (entryLabel eNode) atEntry
-- | Add to the accumulator the closure label of every label literal, found
-- at any depth in the node's expressions, for which 'hasCAF' holds.
cafsInNode :: CmmNode e x -> CAFSet -> CAFSet
cafsInNode node acc0 = foldExpDeep visit node acc0
  where
    -- The accumulator is forced at every step so no thunk chain builds up.
    visit expr !acc
      | CmmLit lit <- expr = visitLit lit acc
      | otherwise          = acc

    visitLit (CmmLabel c)              acc = addLbl c acc
    visitLit (CmmLabelOff c _)         acc = addLbl c acc
    visitLit (CmmLabelDiffOff c1 c2 _) acc = addLbl c1 $! addLbl c2 acc
    visitLit _                         acc = acc

    -- Only labels that may refer to a CAF are recorded, as closure labels.
    addLbl l acc
      | hasCAF l  = Set.insert (toClosureLbl l) acc
      | otherwise = acc
-- | Run the backward CAF analysis over a graph, starting from an empty fact
-- base. The result maps each block to the CAFs referenced by that block or
-- by its successors.
cafAnal :: CmmGraph -> CAFEnv
cafAnal graph = analyzeCmmBwd cafLattice cafTransfers graph mapEmpty
-- | The module-level SRT under construction: a table of CAF closure labels
-- that grows as 'addCAF' appends entries.
data TopSRT = TopSRT { lbl      :: CLabel          -- label of the table (see 'srtToData')
                     , next_elt :: Int             -- index the next appended entry will get
                     , rev_elts :: [CLabel]        -- entries, most recently appended first
                     , elt_map  :: Map CLabel Int } -- index of each label's latest entry

instance Outputable TopSRT where
  -- Debug rendering: label, next free index, entries (newest first) and
  -- the label-to-index map.
  ppr (TopSRT top nextIx newestFirst indexOf)
    = text "TopSRT:" <+> ppr top <+> ppr nextIx
                     <+> ppr newestFirst <+> ppr indexOf
-- | A fresh, empty top-level SRT whose label is built from a new unique.
emptySRT :: MonadUnique m => m TopSRT
emptySRT = do
  u <- getUniqueM
  return TopSRT { lbl      = mkTopSRTLabel u
                , next_elt = 0
                , rev_elts = []
                , elt_map  = Map.empty }
-- | True when no entry has been appended to the SRT yet.
isEmptySRT :: TopSRT -> Bool
isEmptySRT = null . rev_elts
-- | Does the SRT already contain an entry for this label?
cafMember :: TopSRT -> CLabel -> Bool
cafMember srt caf = caf `Map.member` elt_map srt
-- | Index of the label's latest entry in the SRT, if it has one.
cafOffset :: TopSRT -> CLabel -> Maybe Int
cafOffset srt caf = Map.lookup caf $ elt_map srt
-- | Append a label to the SRT at index 'next_elt' and bump that index.
-- No membership check is made; a repeated label gets a new entry and the
-- map is updated to point at it.
addCAF :: CLabel -> TopSRT -> TopSRT
addCAF caf srt@(TopSRT { next_elt = ix, rev_elts = elts, elt_map = idxMap })
  = srt { next_elt = ix + 1
        , rev_elts = caf : elts
        , elt_map  = Map.insert caf ix idxMap }
-- | Emit the top-level SRT as a read-only data section. Its entries are
-- label literals laid out in the order they were appended.
srtToData :: TopSRT -> CmmGroup
srtToData srt = [CmmData section (Statics top entries)]
  where
    top     = lbl srt
    section = Section RelocatableReadOnlyData top
    entries = [ CmmStaticLit (CmmLabel c) | c <- reverse (rev_elts srt) ]
-- | Build the SRT for one info table from its (flattened) CAF set. The
-- top-level SRT may be extended so that every CAF can be addressed from it.
--
-- Returns the updated top-level SRT, an optional separately emitted SRT
-- descriptor (see 'to_SRT'), and the 'C_SRT' for the info table.
buildSRT :: DynFlags -> TopSRT -> CAFSet -> UniqSM (TopSRT, Maybe CmmDecl, C_SRT)
buildSRT dflags topSRT cafs = do
    (sub_tbl, blockSRT) <- procpointSRT dflags (lbl topSRT') (elt_map topSRT') cafList
    return (topSRT', sub_tbl, blockSRT)
  where
    cafList = Set.elems cafs

    topSRT'
      -- Too many CAFs for a single small bitmap: a large SRT is needed
      -- anyway, so it is enough that every CAF appears somewhere.
      | cafList `lengthExceeds` maxBmpSize dflags
      = foldl' add_if_missing topSRT cafList
      -- Otherwise keep all the CAFs near the young end of the top SRT.
      | otherwise
      = add_if_too_far topSRT cafList

    add_if_missing srt caf
      | cafMember srt caf = srt
      | otherwise         = addCAF caf srt

    -- A CAF more than 'maxBmpSize' entries from the young end of the SRT
    -- is appended again. A CAF not in the SRT counts as infinitely far.
    -- Candidates are visited farthest first: absent ones, then ascending
    -- index. The walk therefore stops at the first CAF that is already
    -- near enough, because every remaining candidate is at least as near.
    add_if_too_far srt@(TopSRT {elt_map = m}) cs =
        go srt (sortBy farthestFst cs)
      where
        farthestFst x y = case (Map.lookup x m, Map.lookup y m) of
          (Nothing, Nothing) -> EQ
          (Nothing, Just _)  -> LT
          (Just _,  Nothing) -> GT
          (Just d,  Just d') -> compare d d'

        go s [] = s
        go s@(TopSRT {next_elt = next}) (caf : rest) =
          case cafOffset s caf of
            Just ix | next - ix <= maxBmpSize dflags -> s
            _                                        -> go (addCAF caf s) rest
-- | Describe the given CAFs as a window into the top-level SRT. The window
-- starts at the smallest of their indices (looked up in @top_table@). Its
-- length and membership bitmap are then handed to 'to_SRT'. With no CAFs
-- there is no SRT at all.
procpointSRT :: DynFlags -> CLabel -> Map CLabel Int -> [CLabel] ->
                UniqSM (Maybe CmmDecl, C_SRT)
procpointSRT _ _ _ [] =
  return (Nothing, NoC_SRT)
procpointSRT dflags top_srt top_table entries =
  bitmap `seq` to_SRT dflags top_srt offset len bitmap
  where
    -- Every entry must already be present in the top-level SRT.
    offsets        = sort [ expectJust "constructSRT" (Map.lookup e top_table)
                          | e <- entries ]
    -- 'head'/'last' are safe: the empty case is handled by the first equation.
    offset         = head offsets
    bitmap_entries = [ o - offset | o <- offsets ]
    len            = GhcPrelude.last bitmap_entries + 1
    bitmap         = intsToBitmap dflags len bitmap_entries
-- | Largest span an SRT bitmap may cover while still being stored inline:
-- half the number of bits in a target word.
maxBmpSize :: DynFlags -> Int
maxBmpSize = (`div` 2) . widthInBits . wordWidth
-- | Turn an (offset, length, bitmap) window into the top SRT into a 'C_SRT'.
--
-- A separate descriptor table is emitted in two cases:
--
-- * the span is longer than 'maxBmpSize';
-- * the bitmap collides with the escape value.
--
-- The descriptor holds the window's address, its length and the bitmap
-- words, and the returned 'C_SRT' points at it using 'srtEscape'.
-- Otherwise the single bitmap word is stored directly as a half word.
to_SRT :: DynFlags -> CLabel -> Int -> Int -> Bitmap -> UniqSM (Maybe CmmDecl, C_SRT)
to_SRT dflags top_srt off len bmp
  | needsDescriptor
  = do uniq <- getUniqueM
       let desc_lbl = mkLargeSRTLabel uniq
           desc_tbl = CmmData (Section RelocatableReadOnlyData desc_lbl) $
                        Statics desc_lbl (map CmmStaticLit descriptor_lits)
       return (Just desc_tbl, C_SRT desc_lbl 0 (srtEscape dflags))
  | otherwise
  = return (Nothing, C_SRT top_srt off (toStgHalfWord dflags (fromStgWord (head bmp))))
  where
    escapeBitmap    = [toStgWord dflags (fromStgHalfWord (srtEscape dflags))]
    needsDescriptor = len > maxBmpSize dflags || bmp == escapeBitmap
    descriptor_lits = cmmLabelOffW dflags top_srt off
                    : mkWordCLit dflags (fromIntegral len)
                    : map (mkStgWordCLit dflags) bmp

-- | For a procedure, return the CAFs of its entry block together with its
-- closure label. The label is present only when the top info table's rep is
-- neither static nor a stack frame. Data declarations contribute nothing.
localCAFInfo :: CAFEnv -> CmmDecl -> (CAFSet, Maybe CLabel)
localCAFInfo _ (CmmData _ _) = (Set.empty, Nothing)
localCAFInfo cafEnv proc@(CmmProc _ top_l _ (CmmGraph {g_entry = entry})) =
    (cafs, closureLbl)
  where
    cafs = expectJust "maybeBindCAFs" (mapLookup entry cafEnv)
    closureLbl = case topInfoTable proc of
      Just (CmmInfoTable { cit_rep = rep })
        | not (isStaticRep rep || isStackRep rep) -> Just (toClosureLbl top_l)
      _                                           -> Nothing
-- | Compute, for every closure label, its CAF set with references to other
-- closures replaced by those closures' own (already flattened) CAF sets.
--
-- Closures are processed in the dependency order given by
-- 'stronglyConnCompFromEdgedVerticesOrd'. All members of a recursive group
-- share the union of the group's CAF sets, minus the group's own labels.
mkTopCAFInfo :: [(CAFSet, Maybe CLabel)] -> Map CLabel CAFSet
mkTopCAFInfo localCAFs = foldl' addToTop Map.empty sccs
  where
    addToTop env (AcyclicSCC (l, cafset)) =
      Map.insert l (flatten env cafset) env
    addToTop env (CyclicSCC nodes) =
      let (lbls, cafsets) = unzip nodes
          cafset = foldr Set.delete (Set.unions cafsets) lbls
      in foldl' (\acc l -> Map.insert l (flatten acc cafset) acc) env lbls

    sccs = stronglyConnCompFromEdgedVerticesOrd
             [ DigraphNode (l, cafs) l (Set.elems cafs)
             | (cafs, Just l) <- localCAFs ]
-- | Replace every label in the set that has an entry in @env@ by that
-- entry's CAF set. Labels without an entry are kept as they are.
flatten :: Map CLabel CAFSet -> CAFSet -> CAFSet
flatten env = foldSet expand Set.empty
  where
    expand caf acc =
      maybe (Set.insert caf acc) (foldSet Set.insert acc) (Map.lookup caf env)
-- | Attach a flattened CAF set to every info table of a procedure.
--
-- The entry block uses the top-level result for the procedure's closure
-- label when there is one. Otherwise it uses the flattened local CAFs.
-- Other info tables use their own flattened analysis result, or the empty
-- set if their label has no entry in the analysis environment. This can
-- presumably happen when that block was optimised away. NOTE(review):
-- confirm that such info tables are never emitted.
-- Non-procedure declarations get no CAF information.
bundle :: Map CLabel CAFSet
       -> (CAFEnv, CmmDecl)
       -> (CAFSet, Maybe CLabel)
       -> (LabelMap CAFSet, CmmDecl)
bundle flatmap (env, decl@(CmmProc infos _lbl _ g)) (closure_cafs, mb_lbl)
  = ( mapMapWithKey cafsFor (info_tbls infos), decl )
  where
    entry = g_entry g

    entryCafs = case mb_lbl of
      Just l  -> expectJust "bundle" (Map.lookup l flatmap)
      Nothing -> flatten flatmap closure_cafs

    cafsFor l _
      | l == entry = entryCafs
      | otherwise  = maybe Set.empty (flatten flatmap) (mapLookup l env)
bundle _flatmap (_, decl) _
  = ( mapEmpty, decl )
-- | Pair every declaration with its analysis environment, compute the
-- module-wide transitive CAF map, then attach flattened CAF sets to each
-- declaration's info tables via 'bundle'.
flattenCAFSets :: [(CAFEnv, [CmmDecl])] -> [(LabelMap CAFSet, CmmDecl)]
flattenCAFSets cpsdecls = zipWith (bundle flatmap) envDecls localCAFs
  where
    envDecls  = concatMap (\(env, decls) -> map ((,) env) decls) cpsdecls
    localCAFs = map (uncurry localCAFInfo) envDecls
    -- transitive closure of the per-declaration CAF information
    flatmap   = mkTopCAFInfo localCAFs
-- | Build SRTs for all declarations, threading the top-level SRT through.
-- Each procedure gets its info tables' SRT fields filled in. Any extra SRT
-- descriptor tables are emitted alongside it. Other declarations are passed
-- through unchanged. Output order follows input order: the accumulator is
-- built in reverse and flipped at the end.
doSRTs :: DynFlags
       -> TopSRT
       -> [(CAFEnv, [CmmDecl])]
       -> IO (TopSRT, [CmmDecl])
doSRTs dflags topSRT tops = do
    us <- mkSplitUniqSupply 'u'
    let caf_decls      = flattenCAFSets tops
        (topSRT', gs') = initUs_ us (foldM setSRT (topSRT, []) caf_decls)
    return (topSRT', reverse gs')
  where
    setSRT (srt, acc) (caf_map, decl@(CmmProc{})) = do
      (srt', srt_tables, srt_env) <- buildSRTs dflags srt caf_map
      return (srt', updInfoSRTs srt_env decl : srt_tables ++ acc)
    setSRT (srt, acc) (_, decl) =
      return (srt, decl : acc)
-- | Run 'buildSRT' for every info-table label of one procedure, threading
-- the top-level SRT through. Collects any descriptor tables it produces and
-- the resulting 'C_SRT' per label.
buildSRTs :: DynFlags -> TopSRT -> LabelMap CAFSet
          -> UniqSM (TopSRT, [CmmDecl], LabelMap C_SRT)
buildSRTs dflags top_srt caf_map =
    foldM step (top_srt, [], mapEmpty) (mapToList caf_map)
  where
    step (srt, decls, srt_env) (l, cafs) = do
      (srt', mb_decl, c_srt) <- buildSRT dflags srt cafs
      let decls' = maybe decls (: decls) mb_decl
      return (srt', decls', mapInsert l c_srt srt_env)
-- | Set the SRT field of every info table of a procedure from @srt_env@.
-- Every info-table label must have an entry. Non-procedure declarations
-- are returned unchanged.
updInfoSRTs :: LabelMap C_SRT -> CmmDecl -> CmmDecl
updInfoSRTs srt_env (CmmProc top_info top_l live g) =
    CmmProc top_info' top_l live g
  where
    top_info' = top_info { info_tbls = mapMapWithKey setSrt (info_tbls top_info) }
    setSrt l info_tbl =
      info_tbl { cit_srt = expectJust "updInfo" (mapLookup l srt_env) }
updInfoSRTs _ t = t