module Debug.Hoed.CompTree
( CompTree
, Vertex(..)
, mkCompTree
, isRootVertex
, vertexUID
, vertexRes
, replaceVertex
, getJudgement
, setJudgement
, isRight
, isWrong
, isUnassessed
, isAssisted
, isInconclusive
, isPassing
, leafs
, ConstantValue(..)
, unjudgedCharacterCount
#if defined(TRANSCRIPT)
, getTranscript
#endif
, TraceInfo(..)
, traceInfo
, Graph(..) 
)where
import           Control.DeepSeq
import           Control.Exception as E
import           Control.Monad
import           Debug.Hoed.EventForest
import           Debug.Hoed.Observe
import           Debug.Hoed.Span
import           Debug.Hoed.Render
import           Debug.Hoed.Util
import           Data.Bits
import qualified Data.Foldable          as F
import           Data.Graph.Libgraph
import qualified Data.Set               as Set
import           Data.Hashable
import           Data.IntMap.Strict     (IntMap)
import qualified Data.IntMap.Strict     as IntMap
import           Data.IntSet            (IntSet)
import           Data.List              (foldl', unfoldr)
import           Data.Maybe
import           Data.Semigroup
import           Data.Text (Text, pack, unpack)
import qualified Data.Text as T
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VM
import           Data.Vector.Mutable as VM (IOVector)
import qualified Data.Vector.Unboxed    as U
import           Data.Word
import           GHC.Exts               (IsList (..))
import           GHC.Generics
import           Prelude                hiding (Right)
-- | A vertex of the computation tree: either the artificial root, or a
-- recorded computation statement together with the current judgement of it.
data Vertex = RootVertex | Vertex {vertexStmt :: CompStmt, vertexJmt :: Judgement}
  deriving (Show,Generic)

-- | Vertices are identified by their statement only; the judgement is
-- ignored, so a vertex keeps its identity while its judgement changes.
instance Eq Vertex where
  RootVertex == RootVertex   = True
  v1@Vertex{} == v2@Vertex{} = vertexStmt v1 == vertexStmt v2
  _ == _ = False

-- | Ordering consistent with the 'Eq' instance: 'RootVertex' sorts first and
-- all other vertices are ordered by their statement only.  A derived
-- instance would also compare judgements, so two vertices that are '=='
-- could compare as unequal and e.g. 'Set.member' would miss them.
instance Ord Vertex where
  compare RootVertex    RootVertex    = EQ
  compare RootVertex    Vertex{}      = LT
  compare Vertex{}      RootVertex    = GT
  compare (Vertex s1 _) (Vertex s2 _) = compare s1 s2
-- | Hashing agrees with the 'Eq' instance: only the statement of a vertex
-- contributes, its judgement is ignored.  The root hashes a fixed constant.
instance Hashable Vertex where
  hashWithSalt salt vertex = case vertex of
    RootVertex    -> hashWithSalt salt (1 :: Int)
    Vertex stmt _ -> hashWithSalt salt stmt
-- NFData instances with no method bodies, so 'rnf' comes from the
-- Generic-based default of deepseq.  Only 'Vertex' is declared (and derives
-- 'Generic') in this module.
-- NOTE(review): the other three look like orphan instances; confirm they are
-- not also defined next to their types.
instance NFData Vertex
instance NFData AssistedMessage
instance NFData Judgement
instance (NFData a, NFData b) => NFData (Arc a b)
-- | The judgement attached to a vertex.  The root vertex carries none of its
-- own and is always reported as 'Right'.
getJudgement :: Vertex -> Judgement
getJudgement vertex = case vertex of
  RootVertex   -> Right
  Vertex _ jmt -> jmt
-- | Replace the judgement of a vertex.  The root vertex carries no
-- judgement, so it is returned unchanged.
setJudgement :: Vertex -> Judgement -> Vertex
setJudgement RootVertex      _   = RootVertex
setJudgement (Vertex stmt _) jmt = Vertex stmt jmt
-- | Is the judgement of the vertex 'Right'?  (Always true for the root.)
isRight :: Vertex -> Bool
isRight = (== Right) . getJudgement

-- | Is the judgement of the vertex 'Wrong'?
isWrong :: Vertex -> Bool
isWrong = (== Wrong) . getJudgement

-- | Is the judgement of the vertex still 'Unassessed'?
isUnassessed :: Vertex -> Bool
isUnassessed = (== Unassessed) . getJudgement

-- | Is the judgement of the vertex of the 'Assisted' form?
isAssisted :: Vertex -> Bool
isAssisted v
  | Assisted _ <- getJudgement v = True
  | otherwise                    = False
-- | True when the judgement is 'Assisted' and at least one of its messages
-- is an 'InconclusiveProperty'.
isInconclusive :: Vertex -> Bool
isInconclusive v
  | Assisted msgs <- getJudgement v = any isInconclusive' msgs
  | otherwise                       = False

-- | Is this message an 'InconclusiveProperty'?
isInconclusive' :: AssistedMessage -> Bool
isInconclusive' msg = case msg of
  InconclusiveProperty{} -> True
  _                      -> False
-- | True when the judgement is 'Assisted' and at least one of its messages
-- is a 'PassingProperty'.
isPassing :: Vertex -> Bool
isPassing v
  | Assisted msgs <- getJudgement v = any isPassing' msgs
  | otherwise                       = False

-- | Is this message a 'PassingProperty'?
isPassing' :: AssistedMessage -> Bool
isPassing' msg = case msg of
  PassingProperty{} -> True
  _                 -> False
-- | Identifier of a vertex: the identifier of its statement, or 1 for the
-- root.
--
-- NOTE(review): event UIDs are 0-based trace indices (see 'traceInfo'), so
-- the root sentinel 1 can coincide with a real statement identifier.  The
-- same literal is hard-coded in 'mkCompTree' and 'addDependency'; a fix
-- (e.g. -1) has to change all three places together.
vertexUID :: Vertex -> UID
vertexUID vertex = case vertex of
  RootVertex    -> 1
  Vertex stmt _ -> stmtIdentifier stmt
-- | The 'stmtRes' text of the statement of a vertex, as a String; the root
-- renders as "RootVertex".
vertexRes :: Vertex -> String
vertexRes RootVertex      = "RootVertex"
vertexRes (Vertex stmt _) = unpack (stmtRes stmt)
-- | The computation tree: a graph over 'Vertex' values whose arcs carry no
-- label.  'mkCompTree' builds it with 'RootVertex' as the root.
type CompTree = Graph Vertex ()
-- | Is this the artificial root of the computation tree?
isRootVertex :: Vertex -> Bool
isRootVertex vertex = case vertex of
  RootVertex -> True
  Vertex{}   -> False
-- | The vertices that are not the source of any arc, in the order in which
-- 'vertices' lists them.
leafs :: CompTree -> [Vertex]
leafs g = [v | v <- vertices g, not (v `Set.member` sources)]
  where
    -- every vertex with at least one outgoing arc
    sources = Set.fromList [src | Arc src _ _ <- arcs g]
-- | Total number of characters in the statement labels of all vertices that
-- are neither 'Right' nor 'Wrong'.  The root counts as 'Right' (see
-- 'getJudgement'), so its missing statement is never inspected.
unjudgedCharacterCount :: CompTree -> Int
unjudgedCharacterCount tree =
  sum [T.length (stmtLabel (vertexStmt v)) | v <- vertices tree, unjudged v]
-- | A vertex without a final verdict: neither 'Right' nor 'Wrong'.
unjudged :: Vertex -> Bool
unjudged v = not (isRight v) && not (isWrong v)

-- | A vertex with a final verdict: 'Right' or 'Wrong'.
judged :: Vertex -> Bool
judged v = isRight v || isWrong v
-- | Put @v@ in place of every non-root vertex of the graph whose 'vertexUID'
-- equals that of @v@.  The root vertex is never replaced.
replaceVertex :: CompTree -> Vertex -> CompTree
replaceVertex g v = mapGraph swap g
  where
    wanted = vertexUID v
    swap RootVertex = RootVertex
    swap old
      | vertexUID old == wanted = v
      | otherwise               = old
-- | Build the computation tree from the recorded statements and the
-- dependencies between them.
--
-- Every statement becomes an 'Unassessed' vertex.  Every pair @(i, j)@ in
-- the dependencies becomes an arc from the vertex with UID @i@ to the vertex
-- with UID @j@, where UID 1 stands for 'RootVertex'.  Calls 'error', listing
-- all recorded statements, when a dependency mentions a UID for which no
-- statement was recorded.
--
-- NOTE(review): event UIDs are 0-based trace indices, so a statement whose
-- identifier is 1 cannot be told apart from the root here.  The sentinel is
-- shared with 'vertexUID' and 'addDependency' and must be changed in all
-- three places at once.
mkCompTree :: [CompStmt] -> Dependencies -> CompTree
mkCompTree cs ds = Graph RootVertex (RootVertex : stmtVertices) arcList
  where
    stmtVertices = [Vertex c Unassessed | c <- cs]

    arcList =
      [ Arc (findVertex from) (findVertex to) ()
      | (from, targets) <- toList ds
      , to <- toList targets
      ]

    -- statements by identifier; on duplicate identifiers the last one wins
    byUID :: IntMap Vertex
    byUID = IntMap.fromList [(stmtIdentifier c, Vertex c Unassessed) | c <- cs]

    findVertex :: UID -> Vertex
    findVertex 1   = RootVertex
    findVertex key = fromMaybe (error (missing key)) (IntMap.lookup key byUID)

    missing key = concat
      [ "mkCompTree: Error, cannot find a statement with UID ", show key, "!\n"
      , "We recorded statements with the following UIDs: ", show (IntMap.keys byUID), "\n"
      , unlines [show (stmtIdentifier c) ++ ": " ++ show c | c <- cs]
      ]
-- | Either the root marker 'CVRoot', or a statement UID with a location and
-- a pair of UIDs ('valMin', 'valMax').
data ConstantValue = ConstantValue { valStmt :: !UID, valLoc :: !Location
                                   , valMin  :: !UID, valMax :: !UID }
                   | CVRoot
                  deriving Eq

-- | Renders as @Root@, or as @Stmt-\<stmt\>-\<loc\>: \<min\>-\<max\>@.
instance Show ConstantValue where
  show CVRoot = "Root"
  show (ConstantValue stmt loc lo hi) =
    concat ["Stmt-", show stmt, "-", show loc, ": ", show lo, "-", show hi]
-- | The top-level function an event belongs to, identified by the UID of the
-- function event itself.
newtype TopLvlFun = TopLvlFun UID deriving Eq

-- | Sentinel meaning "no top-level function recorded for this event".
--
-- It must never coincide with a real event UID.  Event UIDs are 0-based
-- indices into the trace ('traceInfo' uses them to index its details
-- vector).  A non-negative sentinel such as 1 therefore equals the UID of a
-- genuine function event at that index.  'getTopLvlFunOr' would then treat
-- that function as absent and substitute its fallback for all descendants.
-- A negative value is never a valid index.
noTopLvlFun :: TopLvlFun
noTopLvlFun = TopLvlFun (-1)
-- | What 'traceInfo' remembers about each visited event.
data EventDetails = EventDetails
  { topLvlFun_ :: !TopLvlFun
    -- ^ enclosing top-level function, or 'noTopLvlFun'
  , locations  :: ParentPosition -> Bool
    -- ^ queried by a child event with its parent position to obtain the
    --   location flag of that child (see 'collectEventDetails')
  }

-- | The UID inside the 'topLvlFun_' field.
topLvlFun :: EventDetails -> UID
topLvlFun details = case topLvlFun_ details of
  TopLvlFun fun -> fun
-- | Mutable per-event details, indexed by event UID.  The parameter @s@ is
-- not used by the representation.
type EventDetailsStore s = VM.IOVector EventDetails

-- | Read the details of an event.  No bounds check is performed, so the UID
-- must be a valid index into the store.
getEventDetails :: EventDetailsStore s -> UID -> IO EventDetails
getEventDetails store ix = VM.unsafeRead store ix

-- | Overwrite the details of an event, also without a bounds check.
setEventDetails :: EventDetailsStore s -> UID -> EventDetails -> IO ()
setEventDetails store ix details = VM.unsafeWrite store ix details
-- | The top-level function recorded in the details, or the given fallback
-- when the details hold 'noTopLvlFun'.
getTopLvlFunOr :: UID -> EventDetails -> UID
getTopLvlFunOr fallback details
  | recorded == noTopLvlFun = fallback
  | otherwise               = topLvlFun details
  where
    recorded = topLvlFun_ details
-- | Dependencies between computations: each key maps to the set of UIDs it
-- gets an arc to in the computation tree (filled by 'addDependency',
-- consumed by 'mkCompTree').
type Dependencies = IntMap IntSet
-- | State accumulated by 'traceInfo' while walking the trace.
data TraceInfo = TraceInfo
  { computations :: !SpanZipper
    -- ^ computation spans, updated by start/stop/pause/resume
  , dependencies :: !Dependencies
    -- ^ arcs of the computation tree discovered so far
#if defined(TRANSCRIPT)
  , messages     :: !(IntMap String)
    -- ^ debug messages per event UID (only in TRANSCRIPT builds)
#endif
  }
  deriving Show
-- | Attach a debug message to an event.  In TRANSCRIPT builds several
-- messages for the same event are joined with ", " (oldest first);
-- otherwise this is a no-op.
addMessage :: Event -> String -> TraceInfo -> TraceInfo
#if defined(TRANSCRIPT)
addMessage e msg s =
  -- insertWith hands the combining function (new, old)
  s{ messages = IntMap.insertWith (\new old -> old ++ ", " ++ new) (eventUID e) msg (messages s) }

-- | The messages recorded for an event, or the empty string.
getMessage :: Event -> TraceInfo -> String
getMessage e s = IntMap.findWithDefault "" (eventUID e) (messages s)

-- | Render every event on its own line, followed by its indented messages.
-- Each event is prepended to the text built so far, so the lines come out
-- in reverse order of @es@.  A strict left fold avoids building a chain of
-- one thunk per event before any output is produced.
getTranscript :: [Event] -> TraceInfo -> String
getTranscript es t = foldl' (\acc e -> (show e ++ m e) ++ "\n" ++ acc) "" es
  where m e = maybe "" ("\n  " ++) (IntMap.lookup (eventUID e) ms)
        ms = messages t
#else
addMessage _ _ t = t
#endif
-- | Look up the parent of an event and return the location flag that the
-- parent assigns to the position of the event, together with the enclosing
-- top-level function (the parent UID when the parent records none).
collectEventDetails :: EventDetailsStore s -> Event -> IO (Bool,UID)
collectEventDetails v e = do
  let !parent = eventParent e
      !pid    = parentUID parent
  details <- getEventDetails v pid
  let !loc = locations details (parentPosition parent)
      !top = getTopLvlFunOr pid details
  pure (loc, top)
-- | Details for the function event with the given UID.
--
-- The top-level function is inherited from the parent or, when the parent
-- records none, is this event itself.  The location flag of the event is
-- looked up in its parent; child position 0 receives its negation and child
-- position 1 the flag itself.  Any other position fails with a
-- pattern-match error.
mkFunDetails :: EventDetailsStore s -> UID -> Event -> IO EventDetails
mkFunDetails store uid e = do
  let parent = eventParent e
  fromParent <- getEventDetails store (parentUID parent)
  let !ownLoc = locations fromParent (parentPosition parent)
      !topFun = getTopLvlFunOr uid fromParent
      locFun 0 = not ownLoc
      locFun 1 = ownLoc
  return (EventDetails (TopLvlFun topFun) locFun)
-- | Span transitions for the top-level function of the given details: each
-- applies the corresponding SpanZipper operation ('startSpan', 'stopSpan',
-- 'pauseSpan', 'resumeSpan') and logs a transcript message.
start, stop,pause,resume :: Event -> EventDetails -> TraceInfo -> TraceInfo
start  = spanTransition "Start computation "  startSpan
stop   = spanTransition "Stop computation "   stopSpan
pause  = spanTransition "Pause up to "        pauseSpan
resume = spanTransition "Resume computation " resumeSpan

-- | Apply a span operation to the top-level function recorded in the
-- details and log "<label><uid>: <resulting spans>" for the event.
spanTransition :: String -> (UID -> SpanZipper -> SpanZipper)
               -> Event -> EventDetails -> TraceInfo -> TraceInfo
spanTransition label step e ed s =
  addMessage e (label ++ show fun ++ ": " ++ show spans') s{computations = spans'}
  where
    fun    = topLvlFun ed
    spans' = step fun (computations s)
-- | UIDs of the spans in the 'Computing' state, in the order the span
-- zipper lists them.
activeComputations :: TraceInfo -> [UID]
activeComputations s = [getSpanUID sp | sp@Computing{} <- toList (computations s)]
-- | Record a dependency from the current active computations.
--
-- With at least two active computations (as listed by
-- 'activeComputations'), an arc is added from the second to the first.
-- With exactly one, the arc comes from the root (UID 1).  With none,
-- nothing is added.  The event is only used for the transcript message.
--
-- NOTE(review): 1 is also a valid event UID, so the root sentinel can
-- collide with a real computation.  It is shared with 'mkCompTree' and
-- 'vertexUID' and must be changed in all three places together.
addDependency :: Event -> TraceInfo -> TraceInfo
addDependency e s = logIt s{dependencies = deps'}
  where
    dep = case activeComputations s of
      []              -> Nothing
      [inner]         -> Just (1, inner)
      (inner:outer:_) -> Just (outer, inner)

    deps' = maybe (dependencies s)
                  (\(from, to) -> IntMap.insertWith (<>) from [to] (dependencies s))
                  dep

    logIt = addMessage e $ case dep of
      Nothing         -> "does not add dependency"
      Just (from, to) -> "adds dependency " ++ show from ++ " -> " ++ show to
-- | One bitmap per event UID: bit @k@ of entry @u@ is set exactly when some
-- 'Cons' or 'ConsChar' event has parent @u@ at position @k@.
type ConsMap = U.Vector Word
-- | Build the 'ConsMap' of a trace in a single pass over its events.
-- NOTE(review): a position at or beyond the bit width of 'Word' cannot be
-- represented; 'setBit' presumably ignores it and 'testBit' reports False.
mkConsMap :: Trace -> ConsMap
mkConsMap t =
  U.create $ do
    -- one zeroed word per event, indexed by UID
    v <- VM.replicate (VG.length t) 0
    VG.forM_ t $ \e ->
      when (isCons (change e)) $ do
          -- mark the position this constructor occupies in its parent
          let p = eventParent e
#if __GLASGOW_HASKELL__ >= 800
          VM.unsafeModify v (`setBit` fromIntegral(parentPosition p)) (parentUID p)
#else
          let ix = parentUID p
          x <- VM.unsafeRead v ix
          VM.unsafeWrite v ix (x `setBit` fromIntegral(parentPosition p))
#endif
    return v
  where
    isCons Cons{} = True
    isCons ConsChar{} = True
    isCons _ = False
-- | Does the 'ConsMap' record a constructor event at the parent position of
-- this event?  The parent UID is used as an index without a bounds check.
corToCons :: ConsMap -> Event -> Bool
corToCons cm e
  | bits == 0 = False
  | otherwise = testBit bits (fromIntegral (parentPosition parent))
  where
    parent = eventParent e
    bits   = U.unsafeIndex cm (parentUID parent)
-- | Walk the trace once, in UID order, computing the computation spans and
-- the dependencies between computations.  Progress dots are emitted through
-- 'condPutStr' roughly every 1% of the trace.
traceInfo :: Verbosity -> Trace -> IO TraceInfo
traceInfo verbose trc = do
  condPutStr verbose "Calculating the edges of the computation graph"
  -- per-event details; events never written below keep this default
  v <- VM.replicate l $ EventDetails noTopLvlFun (const False)
  let loop !s uid e = do
        when (uid `mod` l100 == 0) $ condPutStr verbose "."
        case (change e) of
          -- observation root: no top-level function, every child gets True
          Observe {} -> do
            setEventDetails v uid (EventDetails noTopLvlFun (const True))
            return s
          Fun {} -> do
            setEventDetails v uid =<< mkFunDetails v uid e
            return s
          -- Enter events only matter when a constructor event occupies the
          -- same parent position (see corToCons); otherwise they are skipped
          Enter {}
            | corToCons cs e -> do
              (loc, top) <- collectEventDetails v e
              let !details = EventDetails (TopLvlFun top) (const loc)
              setEventDetails v uid details
              -- location True: start the span and record a dependency;
              -- location False: pause instead
              return $ if loc
                  then addDependency e . start e details $ s
                  else pause e details s
            | otherwise -> return s
          -- every other change (presumably constructor events): location
          -- True stops the span, False resumes it
          other -> do
            (loc, top) <- collectEventDetails v e
            let !details = EventDetails (TopLvlFun top) (const loc)
            setEventDetails v uid details
            return $ if loc
              then stop e details s
              else resume e details s
  VG.ifoldM' loop s0 trc
  where
    l = VG.length trc
    -- progress-dot interval, at least 1 to avoid division by zero
    l100 = max 1 (l `div` 100)
    s0 :: TraceInfo
    s0 = TraceInfo [] []
#if defined(TRANSCRIPT)
           IntMap.empty
#endif
    cs :: ConsMap
    cs = mkConsMap trc