{-# LANGUAGE GADTs, BangPatterns #-}
module CmmCommonBlockElim
  ( elimCommonBlocks
  )
where
import GhcPrelude hiding (iterate, succ, unzip, zip)
import BlockId
import Cmm
import CmmUtils
import CmmSwitch (eqSwitchTargetWith)
import CmmContFlowOpt
import Hoopl.Block
import Hoopl.Graph
import Hoopl.Label
import Hoopl.Collections
import Data.Bits
import Data.Maybe (mapMaybe)
import qualified Data.List as List
import Data.Word
import qualified Data.Map as M
import Outputable
import UniqFM
import UniqDFM
import qualified TrieMap as TM
import Unique
import Control.Arrow (first, second)
-- | Eliminate blocks of the graph whose bodies are equal.
--
-- Blocks are first bucketed by 'hash_block'; each block then becomes a
-- singleton candidate keyed by its successor labels.  'iterate' turns those
-- candidates into a label substitution, which is handed to 'copyTicks' and
-- finally applied to the graph with 'replaceLabels'.
elimCommonBlocks :: CmmGraph -> CmmGraph
elimCommonBlocks graph = replaceLabels subst (copyTicks subst graph)
  where
    -- Only blocks sharing a hash bucket are ever compared with each other.
    hashBuckets = groupByInt hash_block (postorderDfs graph)

    -- Every block starts out alone, keyed by its (unsubstituted) successors.
    toCandidate blk = (successors blk, [blk])
    initialCandidates = map (map toCandidate) hashBuckets

    subst = iterate mapEmpty initialCandidates
-- | A list of blocks; 'mergeBlocks' only adds a block to such a list when
-- its body differs from every block already in it.
type DistinctBlocks = [CmmBlock]
-- | Grouping key for a block: its list of successor labels (rewritten with
-- the current substitution on every round of 'iterate').
type Key = [Label]
-- | Substitution built by this pass: maps the entry label of a merged-away
-- block to the entry label of the block that replaces it.
type Subst = LabelMap BlockId
-- | Repeatedly merge equal blocks until a round discovers nothing new.
--
-- The outer list holds one entry per hash bucket.  In each round the
-- candidates of every bucket are grouped by key ('groupByLabel'), the groups
-- are collapsed with 'mergeBlockList', and the substitutions found along the
-- way are collected.  If the round found none, the substitution passed in is
-- the result; otherwise the keys are rewritten with the extended
-- substitution and another round is run.
iterate :: Subst -> [[(Key, DistinctBlocks)]] -> Subst
iterate subst blocks =
    if mapNull roundSubst
      then subst
      else iterate extendedSubst rekeyed
  where
    -- Candidates of each bucket that share the same successor key.
    byKey :: [[(Key, [DistinctBlocks])]]
    byKey = map groupByLabel blocks

    -- Thread the substitution found in this round through all buckets.
    collapsed :: [[(Key, DistinctBlocks)]]
    (roundSubst, collapsed) =
        List.mapAccumL (List.mapAccumL mergeGroup) mapEmpty byKey

    -- Blocks are compared modulo the substitution from previous rounds only.
    mergeGroup !acc (key, candidates) =
        let (found, survivors) = mergeBlockList subst candidates
        in  (acc `mapUnion` found, (key, survivors))

    extendedSubst = subst `mapUnion` roundSubst

    -- Successor labels may have been merged away; bring the keys up to date.
    rekeyed =
        [ [ (map (lookupBid extendedSubst) key, survivors)
          | (key, survivors) <- bucket ]
        | bucket <- collapsed ]
-- | Merge the blocks of @new@ into @existing@.
--
-- Each new block is compared (body equality modulo @subst@) against the
-- blocks of @existing@ only.  If an equal block is found, a mapping from the
-- new block's entry label to the existing block's entry label is recorded;
-- otherwise the new block is added in front of the resulting list.
mergeBlocks :: Subst -> DistinctBlocks -> DistinctBlocks -> (Subst, DistinctBlocks)
mergeBlocks subst existing new = foldr place (mapEmpty, existing) new
  where
    sameBody = eqBlockBodyWith (eqBid subst)

    place blk rest =
      case List.find (sameBody blk) existing of
        -- Duplicate: redirect its label to the block we already have.
        Just twin -> first (mapInsert (entryLabel blk) (entryLabel twin)) rest
        -- Genuinely new: keep it.
        Nothing   -> second (blk :) rest
-- | Collapse several lists of blocks into one by folding 'mergeBlocks' over
-- them from left to right, starting with the first list.  Returns the union
-- of all substitutions produced on the way together with the merged list.
--
-- Panics when given no lists at all.
mergeBlockList :: Subst -> [DistinctBlocks] -> (Subst, DistinctBlocks)
mergeBlockList subst groups = case groups of
    []            -> pprPanic "mergeBlockList" empty
    (seed : rest) -> List.foldl' absorb (mapEmpty, seed) rest
  where
    absorb (acc, merged) next =
      let (found, merged') = mergeBlocks subst merged next
          -- Keep the accumulated substitution evaluated at every step.
          !acc' = acc `mapUnion` found
      in  (acc', merged')
-- | Hash codes used to pre-partition blocks before they are compared.
type HashCode = Int

-- | Cheap structural hash of a block, folded backwards over its nodes.
--
-- The hash must agree with 'eqBlockBodyWith': blocks that compare equal
-- there have to receive the same hash, otherwise they end up in different
-- buckets and are never compared.  'eqBlockBodyWith' drops every middle node
-- accepted by 'dont_care', so those nodes must leave the running hash
-- completely untouched (not even shift it).  Block labels are not hashed,
-- because they are compared modulo a substitution.
--
-- The result is masked to 31 bits so that it is non-negative as an 'Int'.
hash_block :: CmmBlock -> HashCode
hash_block block =
  fromIntegral (foldBlockNodesB3 (hash_fst, hash_mid, hash_lst) block (0 :: Word32) .&. (0x7fffffff :: Word32))
  where hash_fst _ h = h
        -- Ignored nodes are skipped entirely, mirroring the filtering done
        -- in 'eqBlockBodyWith'.
        hash_mid m h
          | dont_care m = h
          | otherwise   = hash_node m + h `shiftL` 1
        hash_lst m h = hash_node m + h `shiftL` 1

        hash_node :: CmmNode O x -> Word32
        hash_node n | dont_care n = 0
        hash_node (CmmAssign r e) = hash_reg r + hash_e e
        hash_node (CmmStore e e') = hash_e e + hash_e e'
        hash_node (CmmUnsafeForeignCall t _ as) = hash_tgt t + hash_list hash_e as
        hash_node (CmmBranch _) = 23 -- the target label is deliberately not hashed
        hash_node (CmmCondBranch p _ _ _) = hash_e p
        hash_node (CmmCall e _ _ _ _ _) = hash_e e
        hash_node (CmmForeignCall t _ _ _ _ _ _) = hash_tgt t
        hash_node (CmmSwitch e _) = hash_e e
        hash_node _ = error "hash_node: unknown Cmm node!"

        hash_reg :: CmmReg -> Word32
        hash_reg   (CmmLocal localReg) = hash_unique localReg
        hash_reg   (CmmGlobal _)    = 19

        hash_e :: CmmExpr -> Word32
        hash_e (CmmLit l) = hash_lit l
        hash_e (CmmLoad e _) = 67 + hash_e e
        hash_e (CmmReg r) = hash_reg r
        hash_e (CmmMachOp _ es) = hash_list hash_e es
        hash_e (CmmRegOff r i) = hash_reg r + cvt i
        hash_e (CmmStackSlot _ _) = 13

        hash_lit :: CmmLit -> Word32
        hash_lit (CmmInt i _) = fromInteger i
        hash_lit (CmmFloat r _) = truncate r
        hash_lit (CmmVec ls) = hash_list hash_lit ls
        hash_lit (CmmLabel _) = 119
        hash_lit (CmmLabelOff _ i) = cvt $ 199 + i
        hash_lit (CmmLabelDiffOff _ _ i) = cvt $ 299 + i
        hash_lit (CmmBlock _) = 191 -- block ids are compared modulo the substitution
        hash_lit (CmmHighStackMark) = cvt 313

        hash_tgt (ForeignTarget e _) = hash_e e
        hash_tgt (PrimTarget _) = 31

        -- Strict left fold: no chain of suspended additions is built up.
        hash_list f = List.foldl' (\z x -> f x + z) (0::Word32)

        cvt = fromIntegral

        hash_unique :: Uniquable a => a -> Word32
        hash_unique = cvt . getKey . getUnique
-- | Node kinds that are ignored when hashing and comparing blocks:
-- comments, ticks and unwind nodes.
dont_care :: CmmNode O x -> Bool
dont_care node = case node of
  CmmComment {} -> True
  CmmTick {}    -> True
  CmmUnwind {}  -> True
  _             -> False
-- | Two block ids are equal modulo a substitution when 'lookupBid' resolves
-- both of them to the same id.
eqBid :: LabelMap BlockId -> BlockId -> BlockId -> Bool
eqBid subst lhs rhs = resolve lhs == resolve rhs
  where
    resolve = lookupBid subst
-- | Resolve a block id through the substitution, following the chain of
-- mappings until an id is reached that has no entry of its own.
lookupBid :: LabelMap BlockId -> BlockId -> BlockId
lookupBid subst = chase
  where
    chase bid = maybe bid chase (mapLookup bid subst)
-- | Equality of middle nodes, using the supplied predicate wherever block
-- ids have to be compared.  Only assignments, stores and unsafe foreign
-- calls can be equal; every other combination is reported as different.
eqMiddleWith :: (BlockId -> BlockId -> Bool)
             -> CmmNode O O -> CmmNode O O -> Bool
eqMiddleWith sameBid lhs rhs = case (lhs, rhs) of
  (CmmAssign reg1 val1, CmmAssign reg2 val2) ->
    reg1 == reg2 && sameExpr val1 val2
  (CmmStore addr1 val1, CmmStore addr2 val2) ->
    sameExpr addr1 addr2 && sameExpr val1 val2
  (CmmUnsafeForeignCall tgt1 res1 args1, CmmUnsafeForeignCall tgt2 res2 args2) ->
    tgt1 == tgt2 && res1 == res2 && eqListWith sameExpr args1 args2
  _ -> False
  where
    sameExpr = eqExprWith sameBid
-- | Structural equality of expressions, using the supplied predicate
-- wherever block ids occur (block literals and young stack areas).
--
-- Two loads are equal only if they load the same type from equal addresses:
-- loads of different types from the same address are different
-- computations, so blocks containing them must not be merged.
eqExprWith :: (BlockId -> BlockId -> Bool)
           -> CmmExpr -> CmmExpr -> Bool
eqExprWith eqBid = eq
 where
  CmmLit l1          `eq` CmmLit l2          = eqLit l1 l2
  CmmLoad e1 t1      `eq` CmmLoad e2 t2      = t1 `cmmEqType` t2 && e1 `eq` e2
  CmmReg r1          `eq` CmmReg r2          = r1==r2
  CmmRegOff r1 i1    `eq` CmmRegOff r2 i2    = r1==r2 && i1==i2
  CmmMachOp op1 es1  `eq` CmmMachOp op2 es2  = op1==op2 && es1 `eqs` es2
  CmmStackSlot a1 i1 `eq` CmmStackSlot a2 i2 = eqArea a1 a2 && i1==i2
  _e1                `eq` _e2                = False

  xs `eqs` ys = eqListWith eq xs ys

  -- Block literals are compared with the predicate, all others with (==).
  eqLit (CmmBlock id1) (CmmBlock id2) = eqBid id1 id2
  eqLit l1 l2 = l1 == l2

  eqArea Old Old = True
  eqArea (Young id1) (Young id2) = eqBid id1 id2
  eqArea _ _ = False
-- | Equality of block bodies, using the supplied predicate for block ids.
-- The entry nodes are not compared, and middle nodes accepted by
-- 'dont_care' are dropped from both blocks before comparing.
eqBlockBodyWith :: (BlockId -> BlockId -> Bool) -> CmmBlock -> CmmBlock -> Bool
eqBlockBodyWith eqBid block block' =
    eqListWith (eqMiddleWith eqBid) mids mids' && eqLastWith eqBid lst lst'
  where
    (mids,  lst)  = relevantParts block
    (mids', lst') = relevantParts block'

    -- The middle nodes that take part in the comparison, plus the last node.
    relevantParts b = (filter (not . dont_care) (blockToList middle), final)
      where (_, middle, final) = blockSplit b
-- | Equality of block-terminating nodes, using the supplied predicate for
-- every block id they mention.  Branches, conditional branches, calls and
-- switches can be equal to a node of the same kind; any other combination
-- is reported as different.
eqLastWith :: (BlockId -> BlockId -> Bool) -> CmmNode O C -> CmmNode O C -> Bool
eqLastWith sameBid lhs rhs = case (lhs, rhs) of
  (CmmBranch dest1, CmmBranch dest2) ->
    sameBid dest1 dest2
  (CmmCondBranch cond1 true1 false1 hint1, CmmCondBranch cond2 true2 false2 hint2) ->
    and [ cond1 == cond2
        , hint1 == hint2
        , sameBid true1 true2
        , sameBid false1 false2
        ]
  (CmmCall tgt1 cont1 regs1 args1 ret1 off1, CmmCall tgt2 cont2 regs2 args2 ret2 off2) ->
    and [ tgt1 == tgt2
        , eqMaybeWith sameBid cont1 cont2
        , args1 == args2
        , ret1 == ret2
        , off1 == off2
        , regs1 == regs2
        ]
  (CmmSwitch scrut1 targets1, CmmSwitch scrut2 targets2) ->
    scrut1 == scrut2 && eqSwitchTargetWith sameBid targets1 targets2
  _ -> False
-- | Lift an equality predicate to 'Maybe': two 'Nothing's are equal, two
-- 'Just's are compared with the predicate, anything else is different.
eqMaybeWith :: (a -> b -> Bool) -> Maybe a -> Maybe b -> Bool
eqMaybeWith eltEq ma mb = case (ma, mb) of
  (Just x, Just y)   -> eltEq x y
  (Nothing, Nothing) -> True
  _                  -> False
-- | Lift an equality predicate to lists: the lists must have the same
-- length and the predicate must hold for every pair of corresponding
-- elements.  Stops at the first mismatch.
eqListWith :: (a -> b -> Bool) -> [a] -> [b] -> Bool
eqListWith f = go
  where
    go (x : xs) (y : ys) = f x y && go xs ys
    go []       []       = True
    go _        _        = False
-- | For every block that is the target of at least one mapping in @env@,
-- copy over the ticks of the blocks mapped to it and combine their tick
-- scopes with its own.  With an empty @env@ the graph is returned
-- unchanged.
copyTicks :: LabelMap BlockId -> CmmGraph -> CmmGraph
copyTicks env g =
    if mapNull env
      then g
      else ofBlockMap (g_entry g) (mapMap absorbTicks blocksByLabel)
  where
    blocksByLabel = toBlockMap g

    -- @env@ inverted: target label -> labels that map to it.
    sourcesOf = mapFoldWithKey addSource M.empty env
    addSource src tgt = M.insertWith (const (src :)) tgt [src]

    -- Fold the ticks of every source block (that is present in the graph)
    -- into the target block.
    absorbTicks target =
      maybe target
            (foldr copy target . mapMaybe (`mapLookup` blocksByLabel))
            (M.lookup (entryLabel target) sourcesOf)

    -- Prepend the ticks of @from@ to the body of @to@ and widen the entry
    -- node's tick scope to cover both blocks.
    copy from to =
      let (CmmEntry lbl toScope, body) = blockSplitHead to
          CmmEntry _ fromScope         = firstNode from
          tickNodes                    = map CmmTick (blockTicks from)
      in  blockJoinHead (CmmEntry lbl (combineTickScopes fromScope toScope))
                        (foldr blockCons body tickNodes)
-- | Group the values of an association list by their key.  The entries are
-- inserted one by one into a trie indexed by the uniques of the key's
-- labels; values with the same key are collected into one list, most
-- recently inserted first.
groupByLabel :: [(Key, a)] -> [(Key, [a])]
groupByLabel entries = TM.foldTM (:) trie []
  where
    trie = List.foldl' insert (TM.emptyTM :: TM.ListMap UniqDFM a) entries

    insert acc (key, val) = TM.alterTM (map getUnique key) bump acc
      where
        bump Nothing          = Just (key, [val])
        bump (Just (_, vals)) = Just (key, val : vals)
-- | Partition a list into groups of elements for which the given function
-- returns the same 'Int'.  The groups are returned in the order produced by
-- 'nonDetEltsUFM'; within a group, later list elements come first.
groupByInt :: (a -> Int) -> [a] -> [[a]]
groupByInt f = nonDetEltsUFM . List.foldl' bucket emptyUFM
  where
    bucket acc x = alterUFM (Just . prepend) acc (f x)
      where
        prepend Nothing      = [x]
        prepend (Just group) = x : group