module Camfort.Specification.Units
          (Solver, removeUnits, checkUnits
                 , inferUnits, synthesiseUnits
                 , inferCriticalVariables)  where
import Data.Data
import Data.Char (isNumber)
import Data.List
import Data.Maybe
import qualified Data.Map as M
import Data.Label.Mono (Lens)
import qualified Data.Label
import Data.Label.Monadic hiding (modify)
import Data.Function
import Data.Matrix
import Data.Generics.Uniplate.Operations
import Control.Monad.State.Strict hiding (gets)
import Control.Monad.Identity
import Camfort.Helpers
import Camfort.Output
import Camfort.Analysis.Annotations hiding (Unitless)
import Camfort.Analysis.Syntax
import Camfort.Analysis.Types
import Camfort.Input
import Camfort.Specification.Units.Debug
import Camfort.Specification.Units.InferenceBackend
import Camfort.Specification.Units.InferenceFrontend
import qualified Camfort.Specification.Units.Synthesis as US
import Camfort.Specification.Units.Strip
import Camfort.Specification.Units.Environment
import Camfort.Specification.Units.Solve
import qualified Language.Fortran.Analysis.Renaming as FAR
import qualified Language.Fortran.Analysis.BBlocks as FAB
import qualified Language.Fortran.Analysis as FA
import qualified Language.Fortran.AST as F
import Camfort.Transformation.Syntax
import qualified Debug.Trace as D
-- | The default 'Solver' value is 'Custom'.
instance Default Solver where
    defaultValue = Custom
-- | The default 'AssumeLiterals' value is 'Poly'.
instance Default AssumeLiterals where
    defaultValue = Poly
-- | Strip unit annotations from a program file.
--
-- This entry point is currently a stub: it has no implementation and always
-- fails when its result is demanded. It raises a descriptive 'error' rather
-- than a bare 'undefined' so the failure names the missing feature.
removeUnits ::
    (Filename, F.ProgramFile Annotation) -> (Report, (Filename, F.ProgramFile Annotation))
removeUnits (_fname, _pf) =
    error "Camfort.Specification.Units.removeUnits: not implemented"
  
-- | Implicit parameters required by the entry points of this module:
-- the 'Solver' to use (@?solver@) and the 'AssumeLiterals' setting
-- (@?assumeLiterals@).
-- NOTE(review): a constraint synonym over implicit parameters needs the
-- ConstraintKinds and ImplicitParams extensions; the pragmas are not visible
-- in this part of the file - confirm they are enabled.
type Params = (?solver :: Solver, ?assumeLiterals :: AssumeLiterals)
-- | Run unit inference on a program file and report its critical variables.
--
-- The program file is returned exactly as given; the result of the analysis
-- is carried only by the report, in which every entry is prefixed with the
-- file name and terminated by a newline.
inferCriticalVariables ::
       Params
    => (Filename, F.ProgramFile Annotation)
    -> (Report, (Filename, F.ProgramFile Annotation))
inferCriticalVariables (fname, pf) = (summary, (fname, pf))
  where
    -- One output line per report entry, each tagged with the file name.
    summary = concatMap tagged (Data.Label.get report finalEnv)
    tagged entry = fname ++ ": " ++ entry ++ "\n"

    -- Message describing the outcome of the critical-variable search.
    criticalsMessage []    = "No critical variables. Appropriate annotations."
    criticalsMessage names = "Critical variables: " ++ intercalate "," names

    -- Input with unit annotations attached, after 'FA.initAnalysis' and
    -- 'FAR.analyseRenames'; basic-block analysis is applied only where the
    -- inference itself is run.
    renamed = FAR.analyseRenames (FA.initAnalysis (fmap mkUnitAnnotation pf))
    renaming = FAR.extractNameMap renamed

    -- Final unit environment; the implicit parameters must be bound around
    -- the whole stateful computation.
    finalEnv =
      let ?criticals = True
          ?debug     = False
          ?nameMap   = renaming
          ?argumentDecls = False
      in  execState
            (do
               doInferUnits (FAB.analyseBBlocks renamed)
               vars <- criticalVars renaming
               report <<++ criticalsMessage vars
               ifDebug debugGaussian)
            emptyUnitEnv
    
    
-- | Check the units of a program file.
--
-- The program file is returned unchanged. The report contains every
-- non-empty entry accumulated during inference (each prefixed with the file
-- name), followed by a line stating how many user variables were checked.
checkUnits ::
       Params
    => (Filename, F.ProgramFile Annotation)
    -> (Report, (Filename, F.ProgramFile Annotation))
checkUnits (fname, pf) = (summary, (fname, pf))
  where
    summary =
        concatMap tagged (filter (not . null) (Data.Label.get report finalEnv))
        ++ fname ++ ": checked " ++ show checkedCount ++ " user variables\n"
    tagged entry = fname ++ ": " ++ entry ++ "\n"

    -- Number of user variables, as computed by 'countVariables' from the
    -- final environment.
    checkedCount =
        countVariables
          (_varColEnv finalEnv)
          (_debugInfo finalEnv)
          (_procedureEnv finalEnv)
          (fst (_linearSystem finalEnv))
          (_unitVarCats finalEnv)

    -- Input with unit annotations attached, run through the analysis
    -- pipeline (init, renaming, basic blocks).
    analysed =
        FAB.analyseBBlocks
          (FAR.analyseRenames (FA.initAnalysis (fmap mkUnitAnnotation pf)))
    renaming = FAR.extractNameMap analysed

    -- Final unit environment after inference.
    finalEnv =
      let ?criticals = False
          ?debug     = False
          ?nameMap   = renaming
          ?argumentDecls = False
      in execState (doInferUnits analysed) emptyUnitEnv
-- | Infer the units of a program file and report the outcome.
--
-- The program file is returned unchanged; only the final environment of the
-- inference/synthesis run is used. The report starts with the file name on
-- its own line, lists every non-empty report entry, and ends with the number
-- of user variables checked or inferred.
inferUnits ::
       Params
    => (Filename, F.ProgramFile Annotation)
    -> (Report, (Filename, F.ProgramFile Annotation))
inferUnits (fname, pf) = (summary, (fname, pf))
  where
    summary =
        fname ++ ":\n"
        ++ concatMap (++ "\n") (filter (not . null) (Data.Label.get report finalEnv))
        ++ fname ++ ": checked/inferred " ++ show inferredCount ++ " user variables\n"

    -- Number of user variables, as computed by 'countVariables' from the
    -- final environment.
    inferredCount =
        countVariables
          (_varColEnv finalEnv)
          (_debugInfo finalEnv)
          (_procedureEnv finalEnv)
          (fst (_linearSystem finalEnv))
          (_unitVarCats finalEnv)

    -- Only the resulting state is needed; the synthesised program is dropped.
    finalEnv = execState inferAndSynthesise emptyUnitEnv

    -- Infer units and, if inference succeeded, run synthesis with its flag
    -- set to True.
    inferAndSynthesise =
        let ?criticals = False
            ?debug     = False
            ?nameMap   = renaming
            ?argumentDecls = False
        in do
          doInferUnits analysed
          ok <- gets success
          if not ok
            then return analysed
            else US.synthesiseUnits True analysed

    -- Input with unit annotations attached, run through the analysis
    -- pipeline (init, renaming, basic blocks).
    analysed =
        FAB.analyseBBlocks
          (FAR.analyseRenames (FA.initAnalysis (fmap mkUnitAnnotation pf)))
    renaming = FAR.extractNameMap analysed
-- | Infer units for a program file and synthesise annotations into it.
--
-- Returns the program produced by the inference/synthesis run with the
-- analysis annotations stripped back off. If inference does not succeed, the
-- analysed program is passed through without synthesis. The report lists
-- every non-empty report entry (prefixed with the file name) followed by the
-- number of user variables checked or inferred.
synthesiseUnits ::
       Params
    => (Filename, F.ProgramFile Annotation)
    -> (Report, (Filename, F.ProgramFile Annotation))
synthesiseUnits (fname, pf) = (summary, (fname, stripped))
  where
    -- Undo both annotation layers added for the analysis.
    stripped = fmap (prevAnnotation . FA.prevAnnotation) synthesised

    summary =
        concatMap tagged (filter (not . null) (Data.Label.get report finalEnv))
        ++ fname ++ ": checked/inferred " ++ show inferredCount ++ " user variables\n"
    tagged entry = fname ++ ": " ++ entry ++ "\n"

    -- Number of user variables, as computed by 'countVariables' from the
    -- final environment.
    inferredCount =
        countVariables
          (_varColEnv finalEnv)
          (_debugInfo finalEnv)
          (_procedureEnv finalEnv)
          (fst (_linearSystem finalEnv))
          (_unitVarCats finalEnv)

    -- Input with unit annotations attached, run through the analysis
    -- pipeline (init, renaming, basic blocks).
    analysed =
        FAB.analyseBBlocks
          (FAR.analyseRenames (FA.initAnalysis (fmap mkUnitAnnotation pf)))
    renaming = FAR.extractNameMap analysed

    (synthesised, finalEnv) = runState inferAndSynthesise emptyUnitEnv

    -- Infer units; on success synthesise annotations and record how many
    -- were added.
    inferAndSynthesise =
        let ?criticals = False
            ?debug     = False
            ?nameMap   = renaming
            ?argumentDecls = False
        in do
          doInferUnits analysed
          ok <- gets success
          if not ok
            then return analysed
            else do
              annotated <- US.synthesiseUnits False analysed
              (added, _) <- gets evUnitsAdded
              report <<++ ("Added " ++ show added ++ " annotations")
              return annotated
-- | Count the columns of the constraint matrix that correspond to user
-- variables.
--
-- A column is counted when its category is 'Variable' or 'Argument' and
-- 'lookupVarsByCols' finds at least one variable for it. Matrix columns are
-- numbered from 1 and are paired positionally with the category list, so
-- column @c@ takes the category at list index @c - 1@.
--
-- The second and third arguments are not used; they are kept so existing
-- call sites continue to work. If the category list is shorter than the
-- matrix is wide, the columns without a category are not counted.
countVariables vars _debugs _procs matrix ucats =
    length [ () | (col, cat) <- zip [1 .. ncols matrix] ucats
                , isUserColumn col cat ]
  where
    -- Only variable and argument columns can count as user variables.
    isUserColumn col cat = case cat of
      Variable -> hasVariables col
      Argument -> hasVariables col
      _        -> False
    -- True when at least one variable is recorded for the column.
    hasVariables col = not (null (lookupVarsByCols vars [col]))
-- | Compute the critical variables of the current unit environment.
--
-- Runs 'criticalVars'' over the linear system starting at row 1 and maps
-- each resulting name through the given name map; a name with no entry in
-- the map is returned as it is.
criticalVars :: FAR.NameMap -> State UnitEnv [String]
criticalVars nameMap = do
    columns      <- gets varColEnv
    (coeffs, _)  <- gets linearSystem
    categories   <- gets unitVarCats
    debugEntries <- gets debugInfo
    let found = criticalVars' columns categories coeffs 1 debugEntries
    return (map sourceName found)
  where
    -- Look the name up in the map, falling back to the name itself.
    sourceName v = M.findWithDefault v v nameMap
-- | Collect critical variables from row @i@ of the matrix downwards.
--
-- For each row, @firstNonZeroCoeff matrix ucats@ yields a column number
-- (presumably the column of the row's leading coefficient - confirm against
-- its definition). The columns lying strictly between that column and the
-- one obtained for the next row are looked up with
-- 'lookupVarsByColsFilterByArg'. For the last row, the columns after its
-- column up to the last column of the matrix are looked up.
--
-- A row index beyond the last row yields no variables, so a matrix with no
-- rows terminates instead of recursing past the end.
criticalVars' :: VarColEnv
              -> [UnitVarCategory]
              -> Matrix Rational
              -> Row
              -> DebugInfo -> [String]
criticalVars' varenv ucats matrix i dbgs
  | i > lastRow  = []
  | i == lastRow = trailing
  | otherwise    = gap ++ criticalVars' varenv ucats matrix (i + 1) dbgs
  where
    lastRow = nrows matrix
    width   = ncols matrix
    leading = firstNonZeroCoeff matrix ucats
    lookups = lookupVarsByColsFilterByArg matrix varenv ucats

    -- Column obtained for the current row.
    here = leading i

    -- Columns to the right of the last row's column.
    trailing
      | here /= width = lookups [here + 1 .. width] dbgs
      | otherwise     = []

    -- Columns skipped between this row's column and the next row's column.
    gap
      | next /= here + 1 = lookups [here + 1 .. next - 1] dbgs
      | otherwise        = []
      where
        next = leading (i + 1)