module Camfort.Specification.Units.Synthesis
                (synthesiseUnits, pprintUnitConstant) where
import Data.Function
import Data.List
import Data.Matrix
import Data.Maybe
import Data.Ratio (numerator, denominator)
import qualified Data.Map as M
import Data.Generics.Uniplate.Operations
import Data.Label.Monadic hiding (modify)
import Control.Monad.State.Strict hiding (gets)
import Control.Monad
import qualified Language.Fortran.AST as F
import qualified Language.Fortran.Analysis as FA
import qualified Language.Fortran.Analysis.Renaming as FAR
import qualified Language.Fortran.Util.Position as FU
import qualified Camfort.Output as O (srcSpanToSrcLocs)
import Camfort.Analysis.Annotations hiding (Unitless)
import Camfort.Specification.Units.Environment
import qualified Debug.Trace as D
-- | Annotation type carried by the AST nodes handled in this module: the
-- "Language.Fortran.Analysis" wrapper around a 'UnitAnnotation' over 'A'.
type A1 = FA.Analysis (UnitAnnotation A)
-- | Constraint alias for the implicit parameter @?nameMap@, a
-- 'FAR.NameMap' from "Language.Fortran.Analysis.Renaming". Within this
-- module it is consulted when printing variable names.
type Params = ?nameMap :: FAR.NameMap
-- | Run 'perBlock' over every block of a program file (via 'transformBiM'),
-- threading the 'UnitEnv' state through the traversal.
--
-- The 'Bool' selects the mode and is passed straight to 'perBlock':
-- 'True' reports inferred units, 'False' synthesises unit comments.
synthesiseUnits :: Params => Bool -> F.ProgramFile A1 -> State UnitEnv (F.ProgramFile A1)
synthesiseUnits inferReport pf = transformBiM rewriteBlock pf
  where
    -- The per-block rewrite, specialised to the requested mode.
    rewriteBlock = perBlock inferReport
-- | Handle a single block of the program.
--
-- Only declaration statements are treated specially; every other block is
-- returned unchanged.
--
-- For a declaration statement the names of the declared variables and
-- arrays are collected and their units looked up through the variable/column
-- environment in the state. Then, depending on the flag:
--
-- * 'True' (report mode): one line per variable with a known unit, made of
--   the statement's source span and a @unit (...) :: name@ description, is
--   added to the @report@ field of the state. The block itself is returned
--   unchanged.
--
-- * 'False' (synthesis mode): variables listed in @hasDeclaration@ are
--   skipped. For the rest, the first component of @evUnitsAdded@ is bumped
--   and a comment block is returned in place of the statement. The comment
--   holds one @!= unit (...) :: name@ line per variable with a known unit,
--   each indented to the column of the declaration.
perBlock :: Params => Bool -> F.Block A1 -> State UnitEnv (F.Block A1)
perBlock inferReport s@(F.BlStatement a srcSpan@(FU.SrcSpan lp _) _
                               (F.StDeclaration _ _ _ _ decls)) = do
    vColEnv <- gets varColEnv
    let declNames = getNames (F.aStrip decls)
    if inferReport
      then do
        -- Report mode: describe every variable whose unit could be found.
        units <- mapM (\name -> findUnit name vColEnv) declNames
        mapM_ (\found -> report <<++ mkReport found) [found | Just found <- units]
        return s
      else do
        -- Synthesis mode: only consider variables that are not recorded as
        -- already having a declaration.
        hasDec <- gets hasDeclaration
        units <- mapM (\name -> findUnit name vColEnv)
                      (filter (`notElem` hasDec) declNames)
        -- NOTE(review): 'length units' also counts variables for which no
        -- unit was found; confirm that this is the intended statistic.
        (n, ad) <- gets evUnitsAdded
        evUnitsAdded =: (n + length units, ad)
        let unitDecls = mapMaybe (fmap mkComment) units
        return $ F.BlComment a' span0 (intercalate "\n" unitDecls)
  where
    -- Report line: source span of the statement, a tab, then the unit info.
    mkReport (var, unit) = show (spanLineCol srcSpan) ++ "\t" ++ mkInfo (var, unit)

    -- Textual unit description shared by reports and synthesised comments.
    mkInfo (var, unit) = "unit (" ++ pprintUnitConstant unit ++ ")"
                                  ++ " :: " ++ realName var

    -- Comment line carrying a unit description, indented like the declaration.
    mkComment (var, unit) = tabs ++ "!= " ++ mkInfo (var, unit)

    -- One space per column preceding the declaration
    -- (presumably columns are 1-based, hence the @- 1@ -- TODO confirm).
    tabs = replicate (FU.posColumn lp - 1) ' '

    -- Zero-width span at column 0 of the line on which the declaration starts.
    span0 = FU.SrcSpan (lp {FU.posColumn = 0}) (lp {FU.posColumn = 0})

    -- Innermost annotation with 'refactored' set to the statement's start
    -- location, and the statement annotation rebuilt around it.
    ap = (prevAnnotation (FA.prevAnnotation a)) { refactored = Just loc }
    a' = a {FA.prevAnnotation = (FA.prevAnnotation a) { prevAnnotation = ap }}

    -- Start location of the statement's source span.
    loc = fst $ O.srcSpanToSrcLocs srcSpan

    -- Name found for @v@ in the implicit name map, or @v@ itself when the
    -- map has no entry for it.
    realName v = v `fromMaybe` (v `M.lookup` ?nameMap)

    -- Pair a variable with its unit, if the variable has a column in the
    -- environment and a unit can be found for that column.
    findUnit v colEnv =
      case lookupWithoutSrcSpan v colEnv of
        Just (VarCol m, _) -> do
          unit <- lookupUnit m
          return (fmap (\u -> (v, u)) unit)
        Nothing            -> return Nothing

    -- Names of all declared scalars followed by names of all declared
    -- arrays, restricted to declarators whose expression is an 'F.ExpValue'.
    getNames ds =
         [FA.varName e | F.DeclVariable _ _ e@(F.ExpValue {}) _ _ <- declarators]
      ++ [FA.varName e | F.DeclArray _ _ e@(F.ExpValue {}) _ _ _ <- declarators]
      where
        declarators = universeBi ds :: [F.Declarator A1]
perBlock _ b = return b
-- | Render a unit constant as text.
--
-- * @UnitlessC 1@ is rendered as @1@; any other @UnitlessC r@ as
--   @1**(r)@ using 'show' for @r@.
--
-- * @Unitful@ lists are sorted by power and split at the first strictly
--   positive power. Positive powers form the numerator (or @1@ if there are
--   none); the remaining entries form the denominator, printed after
--   @" / "@ with the absolute value of their power. A power of exactly 1 is
--   omitted; integral powers are printed bare and fractional ones as
--   @(p/q)@.
pprintUnitConstant :: UnitConstant -> String
pprintUnitConstant (UnitlessC 1) = "1"
pprintUnitConstant (UnitlessC r) = "1**(" ++ show r ++ ")"
pprintUnitConstant (Unitful ucs) = top ++ separator ++ bottom
  where
    -- Ascending by power, so all non-positive powers come first.
    ordered = sortBy (compare `on` snd) ucs
    (nonPositive, positive) = span (not . (> 0) . snd) ordered
    -- Denominator entries are shown with a positive exponent.
    flipped = [(name, abs power) | (name, power) <- nonPositive]

    top
      | null positive = "1"
      | otherwise     = unwords (map (uncurry renderPower) positive)
    separator
      | null flipped = ""
      | otherwise    = " / "
    bottom = unwords (map (uncurry renderPower) flipped)

    renderPower name power
      | power == 1 = name
      | otherwise  = name ++ "**" ++ renderExponent power

    renderExponent power
      | denominator power == 1 = show (numerator power)
      | otherwise =
          "(" ++ show (numerator power) ++ "/" ++ show (denominator power) ++ ")"
-- | Find the unit associated with column @m@ of the linear system stored in
-- the state.
--
-- The first row with a non-zero entry in column @m@ is located and passed to
-- 'lookupUnit''. When no row mentions the column, the result is 'Nothing'
-- if the column's category is 'Argument', and the empty unit
-- @Unitful []@ otherwise.
lookupUnit :: Col -> State UnitEnv (Maybe UnitConstant)
lookupUnit m = do
    system@(matrix, _) <- gets linearSystem
    ucats   <- gets unitVarCats
    badCols <- gets underdeterminedCols
    let firstRow = find (\row -> matrix ! (row, m) /= 0) [1 .. nrows matrix]
        -- Matrix indices start at 1 whereas '!!' indexes the category list
        -- from 0, hence the @m - 1@.
        defaultUnit
          | ucats !! (m - 1) == Argument = Nothing
          | otherwise                    = Just (Unitful [])
    return $ maybe defaultUnit (lookupUnit' ucats badCols system m) firstRow
-- | Read the unit for column @m@ off row @n@ of the linear system.
--
-- Arguments: the per-column categories, the list of underdetermined
-- columns, the linear system (matrix and right-hand-side vector), the column
-- @m@ and the row @n@. Matrix rows and columns are indexed from 1; the
-- category list and the vector are indexed from 0, hence the @- 1@ offsets.
--
-- The result is decided in this order:
--
-- 1. 'Nothing' if row @n@ has a non-zero entry in some other column whose
--    category is 'Argument'.
--
-- 2. The vector entry for row @n@ if column @m@ is not an 'Argument' and is
--    not listed as underdetermined.
--
-- 3. 'Nothing' if row @n@ has non-zero entries in columns other than @m@.
--
-- 4. Otherwise the vector entry for row @n@.
lookupUnit' :: [UnitVarCategory] -> [Int] -> LinearSystem -> Int -> Int -> Maybe UnitConstant
lookupUnit' ucats badCols (matrix, vector) m n
  | not (null otherArgCols)                       = Nothing
  | category m /= Argument && m `notElem` badCols = solved
  | nonZeroCols /= [m]                            = Nothing
  | otherwise                                     = solved
  where
    -- Right-hand side of row @n@.
    solved = Just (vector !! (n - 1))
    -- Category of a (1-based) column.
    category col = ucats !! (col - 1)
    allCols = [1 .. ncols matrix]
    -- Does row @n@ mention this column?
    nonZero col = matrix ! (n, col) /= 0
    -- Other columns mentioned by row @n@ that are 'Argument's.
    otherArgCols =
      filter (\col -> col /= m && nonZero col && category col == Argument) allCols
    -- Every column mentioned by row @n@.
    nonZeroCols = filter nonZero allCols
-- | Line and column of a source position, as a pair of 'Int's.
lineCol :: FU.Position -> (Int, Int)
lineCol pos = (fromIntegral (FU.posLine pos), fromIntegral (FU.posColumn pos))
-- | Line/column pairs of the two end points of a source span, in the order
-- in which the span stores them.
spanLineCol :: FU.SrcSpan -> ((Int, Int), (Int, Int))
spanLineCol (FU.SrcSpan from to) = (lineCol from, lineCol to)