module Camfort.Specification.Stencils.InferenceFrontend where
import Control.Monad.State.Strict
import Control.Monad.Reader
import Control.Monad.Writer.Strict hiding (Product)
import Camfort.Analysis.CommentAnnotator
import Camfort.Specification.Stencils.InferenceBackend
import Camfort.Specification.Stencils.Syntax
import Camfort.Specification.Stencils.Annotation
import qualified Camfort.Specification.Stencils.Grammar as Gram
import qualified Camfort.Specification.Stencils.Synthesis as Synth
import Camfort.Analysis.Loops (collect)
import Camfort.Analysis.Annotations
import Camfort.Helpers.Vec
import Camfort.Input
import qualified Camfort.Output as O
import qualified Language.Fortran.AST as F
import qualified Language.Fortran.Analysis as FA
import qualified Language.Fortran.Analysis.Types as FAT
import qualified Language.Fortran.Analysis.Renaming as FAR
import qualified Language.Fortran.Analysis.BBlocks as FAB
import qualified Language.Fortran.Analysis.DataFlow as FAD
import qualified Language.Fortran.Util.Position as FU
import qualified Language.Fortran.Util.SecondParameter as FUS
import Data.Data
import Data.Foldable
import Data.Generics.Uniplate.Operations
import Data.Graph.Inductive.Graph hiding (isEmpty)
import qualified Data.Map as M
import qualified Data.IntMap as IM
import qualified Data.Set as S
import Data.Maybe
import Data.List
import Debug.Trace
-- | Which flavour of stencil inference to run.
data InferMode
  = DoMode        -- ^ infer from stencil-like @do@ loops
  | AssignMode    -- ^ infer from array assignment statements
  | CombinedMode  -- ^ both 'DoMode' and 'AssignMode'
  | EvalMode      -- ^ like 'AssignMode', plus extra EVALMODE diagnostics
  | Synth         -- ^ infer and replace statements with specification comments
  deriving (Eq, Show, Data, Read)

-- | Assignment-based inference is the default mode.
instance Default InferMode where
  defaultValue = AssignMode
-- | State threaded through the 'Inferer' monad.
data InferState = IS
  { ivMap   :: FAD.InductionVarMapByASTBlock
    -- ^ induction variables, looked up by an AST node's 'FA.insLabel'
  , hasSpec :: [(FU.SrcSpan, Variable)]
    -- ^ (statement span, variable) pairs that already carry a specification
  }
-- | Diagnostic messages produced while inferring, each paired with a
-- variable name (the empty string when no variable applies).
type EvalLog = [(String, Variable)]

-- | One inference log entry at a source span: either the specifications
-- inferred there ('Left') or a diagnostic message ('Right').
type LogLine =
  (FU.SrcSpan, Either [([Variable], Specification)] (String, Variable))

-- | Pairs of variable names; presumably the flow cycles computed by
-- 'findVarFlowCycles' (every caller in this file passes @[]@).
type Cycles = [(F.Name, F.Name)]

-- | The inference monad: writes 'LogLine's, reads the cycles, the current
-- program-unit name and the type environment, and threads 'InferState'.
type Inferer =
  WriterT [LogLine]
    (ReaderT (Cycles, F.ProgramUnitName, TypeEnv A)
      (State InferState))

-- | Implicit parameters needed during inference: the flows-to graph and the
-- renaming name map.
type Params = (?flowsGraph :: FAD.FlowsGraph A, ?nameMap :: FAR.NameMap)
-- | Run an 'Inferer' computation.
--
-- The reader environment is @(cycles, puName, tenv)@. The state starts as
-- @IS ivmap []@, i.e. with no pre-existing specifications. Returns the
-- result together with the accumulated log; the final state is discarded.
runInferer :: FAD.InductionVarMapByASTBlock
           -> Cycles
           -> F.ProgramUnitName
           -> TypeEnv A
           -> Inferer a
           -> (a, [LogLine])
runInferer ivmap cycles puName tenv inferer =
    evalState (runReaderT (runWriterT inferer) (cycles, puName, tenv))
              (IS ivmap [])
-- | Run stencil inference over a whole program file.
--
-- Every program unit that has basic-block information is traversed with
-- 'perBlockInfer', as are the file's top-level blocks (as nameless block
-- data). Returns the possibly rewritten program file together with the
-- inference logs of both traversals.
stencilInference :: FAR.NameMap
                 -> InferMode
                 -> F.ProgramFile (FA.Analysis A)
                 -> (F.ProgramFile (FA.Analysis A), [LogLine])
stencilInference nameMap mode pf =
    (F.ProgramFile cm_pus' blocks', log1 ++ log2)
  where
    -- In Synth mode, existing specification comments are parsed and
    -- attached to the AST first. The parser's own log is not returned.
    (pf'@(F.ProgramFile cm_pus blocks), _parseLog)
      | mode == Synth = runWriter (annotateComments Gram.specParser pf)
      | otherwise     = (pf, [])

    -- Data-flow analyses over the (possibly annotated) program file.
    bm    = FAD.genBlockMap pf'
    dm    = FAD.genDefMap bm
    bbm   = FAB.genBBlockMap pf'
    sgr   = FAB.genSuperBBGr bbm
    gr    = FAB.superBBGrGraph sgr
    rd    = FAD.reachingDefinitions dm gr
    flTo  = FAD.genFlowsToGraph bm dm gr rd
    beMap = FAD.genBackEdgeMap (FAD.dominators gr) gr
    ivMap = FAD.genInductionVarMapByASTBlock beMap gr
    -- Types are inferred from the original (un-annotated) program file.
    tenv  = FAT.inferTypes pf

    -- Program units are analysed one at a time.
    (cm_pus', log1) = runWriter (transformBiM perPU cm_pus)

    -- Top-level blocks are analysed as nameless block data.
    (blocks', log2) = runInferer ivMap [] F.NamelessBlockData tenv inferBlocks
    inferBlocks = let ?flowsGraph = flTo
                      ?nameMap    = nameMap
                  in descendBiM (perBlockInfer mode) blocks

    -- Only program units carrying basic-block information are analysed;
    -- the rest pass through untouched.
    perPU :: F.ProgramUnit (FA.Analysis A)
          -> Writer [LogLine] (F.ProgramUnit (FA.Analysis A))
    perPU pu
      | isJust (FA.bBlocks (F.getAnnotation pu)) = do
          let inferPU = let ?flowsGraph = flTo
                            ?nameMap    = nameMap
                        in descendBiM (perBlockInfer mode) pu
              (pu', puLog) = runInferer ivMap [] (FA.puName pu) tenv inferPU
          tell puLog
          return pu'
      | otherwise = return pu
-- | Pairs of distinct variables whose data flows into each other.
--
-- Computed under renaming, on the program file after basic-block analysis.
findVarFlowCycles :: Data a => F.ProgramFile a -> [(F.Name, F.Name)]
findVarFlowCycles pf =
    FAR.underRenaming (\renamed -> findVarFlowCycles' (FAB.analyseBBlocks renamed)) pf
-- | Worker for 'findVarFlowCycles' on an already-analysed program file.
--
-- Builds the variable flows-to map (per 'FAD.genVarFlowsToMap') and
-- returns every pair @(n, m)@ with @n /= m@ such that @m@ is in @n@'s
-- entry and @n@ is in @m@'s entry. Because the test is symmetric, each
-- cycle appears in both orientations.
-- NOTE(review): consider giving this a top-level type signature.
findVarFlowCycles' pf =
    [ (n, m) | (n, targets) <- M.toList flowsTo
             , m            <- S.toList targets
             , flowsBack n m ]
  where
    flowsBack n m = n /= m && maybe False (S.member n) (M.lookup m flowsTo)

    blockMap     = FAD.genBlockMap pf
    defMap       = FAD.genDefMap blockMap
    graph        = FAB.superBBGrGraph (FAB.genSuperBBGr (FAB.genBBlockMap pf))
    reachingDefs = FAD.reachingDefinitions defMap graph
    flowsToGraph = FAD.genFlowsToGraph blockMap defMap graph reachingDefs
    flowsTo      = FAD.genVarFlowsToMap defMap flowsToGraph
-- | Infer specifications for @blocks@ relative to the left-hand-side
-- neighbourhood @lhs@, log them at @span@, and return them.
--
-- In 'EvalMode' the following are also logged at @span@: a tickAssign
-- marker, every diagnostic from 'genSpecifications', and an emptySpec
-- marker for each spec whose 'show' is empty.
genSpecsAndReport :: Params
  => InferMode -> FU.SrcSpan -> [Neighbour]
  -> [F.Block (FA.Analysis A)]
  -> Inferer [([Variable], Specification)]
genSpecsAndReport mode span lhs blocks = do
    ivmap <- gets ivMap
    let (specs, evalInfos) = runWriter (genSpecifications ivmap lhs blocks)
    tell [(span, Left specs)]
    when (mode == EvalMode) $ do
      tell [ (span, Right info) | info <- tickAssign : evalInfos ]
      -- NOTE(review): 'spec' is a tuple, whose 'show' is never empty, so
      -- this condition never holds; confirm the intended emptiness test.
      tell [ (span, Right cannotMakeSpec) | spec <- specs, show spec == "" ]
    return specs
  where
    tickAssign     = ("EVALMODE: assign to relative array subscript (tag: tickAssign)", "")
    cannotMakeSpec = ("EVALMODE: Cannot make spec (tag: emptySpec)", "")
-- | If the expression is a subscript applied directly to a variable, return
-- its (stripped) index list. Data references are looked through to their
-- final component. Anything else yields 'Nothing'.
isArraySubscript :: F.Expression (FA.Analysis A)
                 -> Maybe [F.Index (FA.Analysis A)]
isArraySubscript expr =
  case expr of
    F.ExpSubscript _ _ (F.ExpValue _ _ (F.ValVariable _)) subs ->
      Just (F.aStrip subs)
    F.ExpDataRef _ _ _ inner ->
      isArraySubscript inner
    _ ->
      Nothing
-- | Like 'fromJust', but fails with the caller-supplied message on 'Nothing'.
fromJustMsg :: String -> Maybe a -> a
fromJustMsg msg = fromMaybe (error msg)
-- | Infer (and, in 'Synth' mode, synthesise) stencil specifications for a
-- block, recursing into nested blocks where relevant.
--
-- * 'Synth', comment block: if the comment holds a parsed specification
--   declaration attached to an assignment statement, the (statement span,
--   variable) pairs are appended to 'hasSpec'.
-- * Assignment statements (every mode except 'DoMode'): each LHS that is
--   an array subscript with neighbourhood indices is analysed with
--   'genSpecsAndReport'. In 'Synth' mode the statement is then returned as
--   a comment block holding the specifications of not-yet-specified
--   variables, with the statement's start location stored in the
--   annotation's @refactored@ field.
-- * @do@ loops: in 'DoMode'/'CombinedMode' a stencil-like loop is analysed
--   as a whole; the loop body is always traversed.
-- * Other blocks: children are traversed for their log/state effects.
perBlockInfer :: Params
    => InferMode -> F.Block (FA.Analysis A) -> Inferer (F.Block (FA.Analysis A))
perBlockInfer Synth b@(F.BlComment ann _ _) = do
    let ann' = FA.prevAnnotation ann
    case (stencilSpec ann', stencilBlock ann') of
      (Just (Left (Gram.SpecDec _ vars)), Just block) ->
        case block of
          F.BlStatement _ stmtSpan _ (F.StExpressionAssign {}) ->
            modify $ \st -> st { hasSpec = hasSpec st ++ zip (repeat stmtSpan) vars }
          -- Previously a missing alternative (pattern-match crash); other
          -- attached blocks carry no assignment to record.
          _ -> return ()
      _ -> return ()
    return b

perBlockInfer mode b@(F.BlStatement ann span@(FU.SrcSpan lp _) _ stmnt)
  | mode `elem` [AssignMode, CombinedMode, EvalMode, Synth] = do
    let lhses = [ lhs | F.StExpressionAssign _ _ lhs _
                          <- universe stmnt :: [F.Statement (FA.Analysis A)] ]
    IS ivmap existingSpecs <- get
    specs <- forM lhses $ \lhs ->
      case isArraySubscript lhs of
        Nothing   -> return []
        Just subs ->
          case neighbourIndex ivmap subs of
            Just lhsNeighbours -> genSpecsAndReport mode span lhsNeighbours [b]
            Nothing -> do
              when (mode == EvalMode) $
                tell [ (span, Right ("EVALMODE: LHS is an array subscript we can't handle (tag: LHSnotHandled)", "")) ]
              return []
    if mode == Synth && not (null specs)
      then do
        let realName v = fromMaybe v (M.lookup v ?nameMap)
            alreadySpecified v = (span, realName v) `elem` existingSpecs
            -- Keep only variables without a specification at this span;
            -- drop the entry entirely when none remain.
            unspecified (vars, spec) =
              case filter (not . alreadySpecified) vars of
                []    -> Nothing
                vars' -> Just (vars', spec)
            newSpecs  = concatMap (mapMaybe unspecified) specs
            -- Indent the comment to the statement's column (one less than
            -- posColumn, which appears to be 1-based).
            indent    = replicate (FU.posColumn lp - 1) ' '
            specComment =
              Synth.formatSpec (Just (indent ++ "!= ")) ?nameMap (span, Left newSpecs)
            loc       = fst (O.srcSpanToSrcLocs span)
            lineStart = lp { FU.posColumn = 0 }
            span'     = FU.SrcSpan lineStart lineStart
            ann'      = ann { FA.prevAnnotation =
                                (FA.prevAnnotation ann) { refactored = Just loc } }
        return (F.BlComment ann' span' specComment)
      else return b

perBlockInfer mode b@(F.BlDo ann span x mDoSpec body) = do
    when ((mode == DoMode || mode == CombinedMode) && isStencilDo b) $
      void (genSpecsAndReport mode span [] body)
    body' <- mapM (descendBiM (perBlockInfer mode)) body
    return (F.BlDo ann span x mDoSpec body')

perBlockInfer mode b = do
    -- NOTE(review): the rewritten children are discarded here. In 'Synth'
    -- mode, comments synthesised for statements nested in other block kinds
    -- (e.g. if-branches) are therefore dropped, and only log and state
    -- effects survive. Confirm whether 'descendM' was intended.
    mapM_ (descendBiM (perBlockInfer mode)) (children b)
    return b
-- | Infer specifications for the arrays read by @blocks@ (following data
-- flow backwards), relative to the LHS neighbourhood @lhs@.
--
-- Variables with identical specifications are grouped by 'groupKeyBy'. A
-- bounded spec with both a lower and an upper bound is split into two
-- specs, one per bound. If no specification can be built, an emptySpec
-- diagnostic is logged and @[]@ is returned.
genSpecifications :: Params
  => FAD.InductionVarMapByASTBlock
  -> [Neighbour]
  -> [F.Block (FA.Analysis A)]
  -> Writer EvalLog [([Variable], Specification)]
genSpecifications ivs lhs blocks = do
    let subscripts = evalState (mapM (genSubscripts True) blocks) []
    perVariable <- mapM strength (mkSpecs subscripts)
    case mapMaybe strength perVariable of
      [] -> do
        tell [("EVALMODE: Empty specification (tag: emptySpec)", "")]
        return []
      found ->
        return (concatMap splitBounds (groupKeyBy found))
  where
    mkSpecs = M.toList
            . M.mapWithKey (\v -> indicesToSpec ivs v lhs)
            . M.unionsWith (++)

    -- Pull the effect out of the second component of a pair.
    strength :: Monad m => (a, m b) -> m (a, b)
    strength (key, action) = fmap (\val -> (key, val)) action

    splitBounds (vs, Specification (Left (Bound (Just l) (Just u)))) =
      [ (vs, Specification (Left (Bound (Just l) Nothing)))
      , (vs, Specification (Left (Bound Nothing (Just u)))) ]
    splitBounds other = [other]
-- | Collect the RHS array subscripts (per variable) of a block and of every
-- block flowing into it, according to @?flowsGraph@.
--
-- The state holds the labels already visited, so each node is processed
-- once. When the flag is 'False' (i.e. for a predecessor), an assignment
-- to an array subscript contributes nothing and is not traversed further.
-- Calls 'error' if a block has no 'FA.insLabel'.
genSubscripts :: Params
  => Bool
  -> F.Block (FA.Analysis A)
  -> State [Int] (M.Map Variable [[F.Index (FA.Analysis A)]])
genSubscripts False (F.BlStatement _ _ _ (F.StExpressionAssign _ _ target _))
    | isJust (isArraySubscript target) = return M.empty
genSubscripts _ block =
    case FA.insLabel (F.getAnnotation block) of
      Nothing   -> error $ "Missing a label for: " ++ show block
      Just node -> do
        seen <- gets (elem node)
        if seen
          then return M.empty
          else do
            modify (node :)
            let predecessors = mapMaybe (lab ?flowsGraph) (pre ?flowsGraph node)
            deps <- mapM (genSubscripts False) predecessors
            return (M.unionsWith (++) (genRHSsubscripts block : deps))
-- | Map each variable subscripted on the right-hand side of a block to the
-- list of its (non-empty) index lists, combined via 'collect'.
genRHSsubscripts ::
     F.Block (FA.Analysis A)
  -> M.Map Variable [[F.Index (FA.Analysis A)]]
genRHSsubscripts b = collect (mapMaybe subscriptOf (FA.rhsExprs b))
  where
    subscriptOf (F.ExpSubscript _ _ var subs)
      | isVariableExpr var
      , ixs@(_ : _) <- F.aStrip subs
      = Just (FA.varName var, ixs)
    subscriptOf _ = Nothing
-- | The loop variable of a @do@ specification whose initialiser assigns to
-- a plain variable, as a singleton list; otherwise @[]@.
getInductionVar :: Maybe (F.DoSpecification (FA.Analysis A)) -> [Variable]
getInductionVar mDoSpec =
  case mDoSpec of
    Just (F.DoSpecification _ _ (F.StExpressionAssign _ _ e _) _ _)
      | isVariableExpr e -> [FA.varName e]
    _ -> []
-- | Is this a @do@ loop whose body looks like a stencil?
--
-- The loop must have a plain induction variable. Among the expressions in
-- the loop body up to the first nested @do@, there must be at least one
-- expression, and every non-empty subscript index must be a neighbour of
-- that induction variable.
isStencilDo :: F.Block (FA.Analysis A) -> Bool
isStencilDo (F.BlDo _ _ _ mDoSpec body) =
  case getInductionVar mDoSpec of
    [ivar] -> not (null exprs) &&
              and [ all (`isNeighbour` [ivar]) subs'
                  | F.ExpSubscript _ _ _ subs <- exprs
                  , let subs' = F.aStrip subs
                  , not (null subs') ]
    -- No induction variable (or, defensively, more than one): not a
    -- stencil loop. Previously the case had no catch-all alternative.
    _ -> False
  where
    exprs = universeBi (takeWhile (not . isDo) body) :: [F.Expression (FA.Analysis A)]
    isDo (F.BlDo {}) = True
    isDo _           = False
isStencilDo _ = False
-- | Right-pad every offset list with zeros to the length of the longest one.
--
-- The subtraction in the padding width was missing, which left the
-- original ill-typed. The @0 :@ seed keeps 'maximum' total on @[]@.
padZeros :: [[Int]] -> [[Int]]
padZeros ixss = map pad ixss
  where
    width   = maximum (0 : map length ixss)
    pad ixs = ixs ++ replicate (width - length ixs) 0
-- | Turn the index lists with which variable @a@ is read into a
-- specification relative to the LHS neighbourhood @lhs@.
--
-- Returns 'Nothing' (after logging why) when induction-variable use is
-- inconsistent or any index is a 'NonNeighbour'. Otherwise the indices are
-- relativised against @lhs@, converted to zero-padded offsets, and turned
-- into a specification. Its linearity is set from whether 'hasDuplicates'
-- reported duplicates. Diagnostics about relativisation and dimensionality
-- are logged along the way.
indicesToSpec :: FAD.InductionVarMapByASTBlock
              -> Variable
              -> [Neighbour]
              -> [[F.Index (FA.Analysis Annotation)]]
              -> Writer EvalLog (Maybe Specification)
indicesToSpec ivs a lhs ixs
  | not (consistentIVSuse lhs rhses') =
      reject "EVALMODE: Inconsistent IV use (tag: inconsistentIV)"
  | any (elem NonNeighbour) rhses' =
      reject "EVALMODE: Non-neighbour relative subscripts (tag: nonNeighbour)"
  | otherwise = do
      when (rhses' /= rhses'') $
        tell [("EVALMODE: Relativized spec (tag: relativized)", "")]
      tell [("EVALMODE: dimensionality=" ++ show dimensionality, a)]
      return (setLinearity (fromBool mult) <$> relativeIxsToSpec offsets)
  where
    (rhses', mult) = hasDuplicates (map (map (ixToNeighbour ivs)) ixs)
    rhses''        = relativise lhs rhses'
    -- Safe: NonNeighbour indices were rejected by the guard above.
    offsets        = padZeros (map (fromJust . mapM neighbourToOffset) rhses'')
    dimensionality = case offsets of
                       []        -> 0
                       (row : _) -> length row
    reject msg     = tell [(msg, "")] >> return Nothing
-- | Re-express RHS indices relative to the LHS indices.
--
-- For each LHS @Neighbour v i@, every RHS @Neighbour v j@ becomes
-- @Neighbour v (j - i)@. An RHS whole-dimension index (@Neighbour "" _@)
-- paired with any other LHS variable becomes a constant. The subtraction
-- @j - i@ was missing, which left the original ill-typed.
relativise :: [Neighbour] -> [[Neighbour]] -> [[Neighbour]]
relativise lhs rhses = foldr relativiseRHS rhses lhs
    where
      relativiseRHS (Neighbour lhsIV i) acc = map (map (relativiseBy lhsIV i)) acc
      relativiseRHS _                   acc = acc

      relativiseBy v i (Neighbour u j) | v == u = Neighbour u (j - i)
      relativiseBy _ _ (Neighbour "" _)         = Constant (F.ValInteger "")
      relativiseBy _ _ x                        = x
-- | Check that the RHS index lists use induction variables consistently,
-- position by position, and that every induction variable used on the RHS
-- also appears on the LHS (or one side is a whole-dimension index @""@).
--
-- An empty RHS list is trivially consistent. This version pattern-matches
-- instead of using the partial 'head', 'tail' and 'fromJust'.
consistentIVSuse :: [Neighbour] -> [[Neighbour]] -> Bool
consistentIVSuse _ [] = True
consistentIVSuse lhs (firstRHS : otherRHSes) =
    case foldrM agree firstRHS otherRHSes of
      Nothing       -> False
      Just combined -> all consistentWithLHS combined
  where
    -- Combine two index lists position-wise; Nothing on any conflict.
    agree a b = mapM (uncurry cmp) (zip a b)

    cmp (Neighbour v i) (Neighbour v' _) | v == v'   = Just $ Neighbour v i
                                         | otherwise = Nothing
    cmp n@(Neighbour {})  (Constant _)   = Just n
    cmp (Constant _) n@(Neighbour {})    = Just n
    cmp (NonNeighbour {}) (Neighbour {}) = Nothing
    cmp (Neighbour {}) (NonNeighbour {}) = Nothing
    cmp _ _                              = Just $ Constant (F.ValInteger "")

    consistentWithLHS :: Neighbour -> Bool
    consistentWithLHS (Neighbour rv _) = any (matchesIV rv) lhs
    consistentWithLHS _                = True

    -- The empty variable name (whole-dimension index) matches anything.
    matchesIV :: Variable -> Neighbour -> Bool
    matchesIV v (Neighbour v' _) = v == v' || v == "" || v' == ""
    matchesIV _ _                = False
-- | Build a specification (without linearity) from relative offsets, or
-- 'Nothing' if the resulting specification is empty.
relativeIxsToSpec :: [[Int]] -> Maybe Specification
relativeIxsToSpec ixs
  | isEmpty spec = Nothing
  | otherwise    = Just spec
  where
    spec = inferFromIndicesWithoutLinearity (fromLists ixs)
-- | Does the index classify as a 'Neighbour' with respect to the given
-- induction variables?
isNeighbour :: Data a => F.Index (FA.Analysis a) -> [Variable] -> Bool
isNeighbour ix vs =
    case ixToNeighbour' vs ix of
      Neighbour {} -> True
      _            -> False
-- | Classify every index of a subscript, failing if any is a 'NonNeighbour'.
neighbourIndex :: FAD.InductionVarMapByASTBlock
               -> [F.Index (FA.Analysis Annotation)] -> Maybe [Neighbour]
neighbourIndex ivs ixs
  | NonNeighbour `elem` neighbours = Nothing
  | otherwise                      = Just neighbours
  where
    neighbours = map (ixToNeighbour ivs) ixs
-- | Classification of a single array index.
data Neighbour
  = Neighbour Variable Int
    -- ^ induction variable plus constant offset; the empty name is used
    -- for whole-dimension ranges
  | Constant (F.Value ())
    -- ^ index involving no induction variable
  | NonNeighbour
    -- ^ any other use of an induction variable
  deriving (Eq, Show)
-- | Numeric offset of a classified index. Constants map to 'absoluteRep';
-- a 'NonNeighbour' has no offset.
neighbourToOffset :: Neighbour -> Maybe Int
neighbourToOffset n =
  case n of
    Neighbour _ offset -> Just offset
    Constant _         -> Just absoluteRep
    NonNeighbour       -> Nothing
-- | Classify an index against the induction variables recorded for its
-- 'FA.insLabel' in the map (none if the label has no entry).
--
-- Calls 'error' if the index carries no label.
ixToNeighbour :: FAD.InductionVarMapByASTBlock
              -> F.Index (FA.Analysis Annotation) -> Neighbour
ixToNeighbour ivmap f = ixToNeighbour' ivsList f
  where
    mInsLabel = FA.insLabel (F.getAnnotation f)
    insLbl    = fromJustMsg errorMsg mInsLabel
    errorMsg  = show (ixsspan f)
             ++ " get IVs associated to labelled index "
             ++ show mInsLabel
    ivsList   = maybe [] S.toList (IM.lookup insLbl ivmap)

    ixsspan :: F.Index (FA.Analysis A) -> FU.SrcSpan
    ixsspan (F.IxRange _ sp _ _ _) = sp
    ixsspan (F.IxSingle _ sp _ _)  = sp
-- | Classify an index against an explicit list of induction variables.
--
-- A range with no bounds and either no stride or the literal stride @1@
-- covers a whole dimension and maps to @Neighbour "" 0@. A single index is
-- classified by 'expToNeighbour'. Any other range is a 'NonNeighbour'.
ixToNeighbour' :: Data a => [Variable] -> F.Index (FA.Analysis a) -> Neighbour
ixToNeighbour' _ (F.IxRange _ _ Nothing Nothing Nothing) = Neighbour "" 0
ixToNeighbour' _ (F.IxRange _ _ Nothing Nothing
                   (Just (F.ExpValue _ _ (F.ValInteger "1")))) = Neighbour "" 0
ixToNeighbour' ivs (F.IxSingle _ _ _ e) = expToNeighbour ivs e
ixToNeighbour' _ _ = NonNeighbour
-- | Classify an index expression against the given induction variables.
--
-- * @i@, @i + c@, @c + i@ and @i - c@ (with @i@ an induction variable and
--   @c@ an integer literal) become @Neighbour i offset@.
-- * Plain values (including non-induction variables) become 'Constant'.
-- * Any other expression is 'Constant' 0 if it mentions no induction
--   variable, and 'NonNeighbour' otherwise.
--
-- The subtraction case lost its negation in the original; its
-- @if x < 0 then abs x else -x@ equals @negate x@ in every case. The
-- trailing catch-all clause was unreachable and has been removed.
-- NOTE(review): 'read' fails on literals with a kind suffix (e.g. @1_8@).
expToNeighbour :: forall a. Data a
            => [Variable] -> F.Expression (FA.Analysis a) -> Neighbour
expToNeighbour ivs e@(F.ExpValue _ _ v@(F.ValVariable _))
    | FA.varName e `elem` ivs = Neighbour (FA.varName e) 0
    | otherwise               = Constant (fmap (const ()) v)
expToNeighbour _ (F.ExpValue _ _ val) = Constant (fmap (const ()) val)
expToNeighbour ivs (F.ExpBinary _ _ F.Addition
                     e@(F.ExpValue _ _ (F.ValVariable _))
                     (F.ExpValue _ _ (F.ValInteger offs)))
    | FA.varName e `elem` ivs = Neighbour (FA.varName e) (read offs)
expToNeighbour ivs (F.ExpBinary _ _ F.Addition
                     (F.ExpValue _ _ (F.ValInteger offs))
                     e@(F.ExpValue _ _ (F.ValVariable _)))
    | FA.varName e `elem` ivs = Neighbour (FA.varName e) (read offs)
expToNeighbour ivs (F.ExpBinary _ _ F.Subtraction
                     e@(F.ExpValue _ _ (F.ValVariable _))
                     (F.ExpValue _ _ (F.ValInteger offs)))
    | FA.varName e `elem` ivs = Neighbour (FA.varName e) (negate (read offs))
expToNeighbour ivs e =
    if null ivsUsed then Constant (F.ValInteger "0") else NonNeighbour
  where
    -- Induction variables occurring anywhere inside the expression.
    ivsUsed = [ i | var@(F.ExpValue _ _ (F.ValVariable {}))
                      <- universeBi e :: [F.Expression (FA.Analysis a)]
                  , let i = FA.varName var
                  , i `elem` ivs ]
-- | Is the expression a unary or binary operator application?
isUnaryOrBinaryExpr :: F.Expression a -> Bool
isUnaryOrBinaryExpr e =
  case e of
    F.ExpUnary {}  -> True
    F.ExpBinary {} -> True
    _              -> False
-- | Is the expression a bare variable reference?
isVariableExpr :: F.Expression a -> Bool
isVariableExpr e =
  case e of
    F.ExpValue _ _ (F.ValVariable _) -> True
    _                                -> False
-- | Type environment: for each type scope, the types of the names in it.
-- The parameter @a@ is unused (phantom); it is kept so that existing
-- uses such as @TypeEnv A@ continue to compile.
type TypeEnv a =
  M.Map FAT.TypeScope (M.Map String FA.IDType)
-- | Does variable @v@ have an array type in program unit @name@?
--
-- The unit's local scope is used if present, otherwise the global scope.
-- The global scope is not consulted when the local scope exists but lacks
-- @v@. Unknown names are not arrays.
isArrayType :: TypeEnv A -> F.ProgramUnitName -> String -> Bool
isArrayType tenv name v =
    fromMaybe False $ (== FA.CTArray) <$> (FA.idCType =<< M.lookup v =<< scope)
  where
    scope = M.lookup (FAT.Local name) tenv `mplus` M.lookup FAT.Global tenv