module Camfort.Specification.Units.Solve where
import Data.Ratio
import Data.List
import qualified Data.Matrix as DM
import qualified Data.Vector as V
import Control.Exception
import System.IO.Unsafe
import qualified Debug.Trace as D
import Language.Fortran
import Camfort.Specification.Units.Environment
import Camfort.Specification.Units.SolveHMatrix
-- | Solve a unit linear system with the solver selected by the implicit
-- @?solver@ parameter.
--
-- Only the 'Custom' solver is dispatched here; it routes to 'solveSystemH'
-- (the HMatrix-based solver), not to 'solveSystemC'.
-- NOTE(review): if 'Solver' has constructors other than 'Custom', they are
-- not handled and reaching them raises a pattern-match failure — confirm.
solveSystem :: (?solver :: Solver) => LinearSystem -> Consistency LinearSystem
solveSystem linearSystem =
  case ?solver of
    Custom -> solveSystemH linearSystem
-- | Solve a linear system with the hand-written Gauss-Jordan elimination in
-- this module ('solveSystem''), starting from the top-left corner.
-- Rows and columns are 1-based, matching "Data.Matrix" indexing.
solveSystemC :: LinearSystem -> Consistency LinearSystem
solveSystemC linearSystem = solveSystem' linearSystem firstCol firstRow
  where
    firstCol = 1
    firstRow = 1
-- | One step of Gauss-Jordan elimination.
--
-- For the current column @col@, look for the first row at or below @row@
-- whose coefficient in that column is non-zero.  'elimRow' then either
-- brings that pivot into row @row@ and eliminates the column everywhere
-- (advancing both column and row), or, when there is no pivot, advances
-- only the column.
--
-- Once every column has been processed, 'checkSystem' inspects the
-- right-hand sides of the rows from @row@ onward. The result is passed
-- through 'efmap' (defined elsewhere; presumably it applies 'cutSystem'
-- only to an 'Ok' result), which trims those rows off.
solveSystem' :: LinearSystem -> Col -> Row -> Consistency LinearSystem
solveSystem' sys@(mat, _) col row
  | col > DM.ncols mat = efmap (cutSystem row) (checkSystem sys row)
  | otherwise          = elimRow sys pivotRow col row
  where
    pivotRow    = find nonZeroAt [row .. DM.nrows mat]
    nonZeroAt r = mat DM.! (r, col) /= 0
-- | Keep only the first @k - 1@ rows of the system. These are the rows that
-- received a pivot during elimination ('solveSystem'' passes its final
-- row counter as @k@). The coefficient matrix and the right-hand-side
-- vector are trimmed together.
--
-- When @k <= 1@ (no pivot was found at all), an empty @0 x ncols@ system is
-- returned.  'DM.submatrix' rejects an ending row smaller than the starting
-- row, so that case cannot go through it.
cutSystem :: Int -> LinearSystem -> LinearSystem
cutSystem k (matrix, vector)
  | keep < 1  = (DM.fromList 0 cols [], [])
  | otherwise = (DM.submatrix 1 keep 1 cols matrix, take keep vector)
  where
    keep = k - 1
    cols = DM.ncols matrix
-- | Scan rows @k .. nrows@ (1-based) of an eliminated system for an
-- inconsistency.
--
-- 'solveSystem'' calls this once every column has been processed, so these
-- rows had no pivot.  A row whose right-hand side is anything other than
-- @Unitful []@ is reported as 'Bad', together with its row index, its
-- right-hand side and its coefficient row.  If no such row exists, the
-- system is returned unchanged in 'Ok'.
checkSystem :: LinearSystem -> Row -> Consistency LinearSystem
checkSystem system@(matrix, vector) k
  | k > DM.nrows matrix = Ok system
  | rhs /= Unitful []   = Bad system k (rhs, V.toList (DM.getRow k matrix))
  | otherwise           = checkSystem system (k + 1)
  where
    -- Right-hand side of row k; the vector is a 0-based list.
    rhs = vector !! (k - 1)
-- | Process column @col@ given the pivot search result from 'solveSystem''.
--
-- * No pivot: move on to the next column, staying on the same row.
-- * Pivot in row @pivotRow@:
--   1. divide that row by its pivot so the entry becomes 1;
--   2. swap it into position @row@;
--   3. apply the same scale-and-swap to the right-hand side;
--   4. let 'elimRow'' clear the column from every other row;
--   5. continue with the next column and the next row.
elimRow :: LinearSystem -> Maybe Row -> Col -> Row -> Consistency LinearSystem
elimRow system Nothing col row = solveSystem' system (col + 1) row
elimRow (matrix, vector) (Just pivotRow) col row =
  solveSystem' (elimRow' (normalised, rhs) row col) (col + 1) (row + 1)
  where
    pivot = matrix DM.! (pivotRow, col)
    -- Skip the scaling when the pivot is already 1.
    scaled
      | pivot == 1 = matrix
      | otherwise  = DM.scaleRow (recip pivot) pivotRow matrix
    normalised
      | row == pivotRow = scaled
      | otherwise       = DM.switchRows row pivotRow scaled
    rhs = switchScaleElems row pivotRow (fromRational (recip pivot)) vector
-- | Clear column @m@ from every row except row @k@, using row @k@ as the
-- pivot row.
--
-- Each row @n /= k@ with a non-zero entry @(n, m)@ gets row @k@ multiplied
-- by @-(n, m)@ added to it.  This zeroes the entry only if @(k, m)@ is 1,
-- which the caller must ensure.  Rows are visited in increasing order.
--
-- The multiplier must be the negated entry.  Adding @+(n, m)@ times the
-- pivot row would double the entry instead of eliminating it.
msteeper :: (Eq a, Num a) => DM.Matrix a -> Int -> Int -> DM.Matrix a
msteeper matrix k m =
  foldl' step matrix [n | n <- [1 .. DM.nrows matrix], n /= k]
  where
    step mat n =
      let s = negate (mat DM.! (n, m))
      in if s == 0 then mat else DM.combineRows n s k mat
-- | Eliminate column @m@ from every row other than the pivot row @k@.
--
-- Assumes entry @(k, m)@ is already 1.  'elimRow' scales the pivot row and
-- swaps it into place before calling this.  For every other row @n@:
--
-- * row @n@ of the matrix becomes row @n@ plus row @k@ times @-(n, m)@,
--   which zeroes entry @(n, m)@;
-- * entry @n@ of the right-hand side becomes @v_n - a_(n,m) * v_k@, with
--   @a_(n,m)@ read from the input matrix.
--
-- The pivot row and its right-hand side are left unchanged.
elimRow' :: LinearSystem -> Row -> Col -> LinearSystem
elimRow' (matrix, vector) k m = (matrix', vector')
  where
    -- Negated entry: adding this multiple of the pivot row zeroes (n, m).
    mstep mat n =
      let s = negate (mat DM.! (n, m))
      in if s == 0 then mat else DM.combineRows n s k mat
    matrix' = foldl' mstep matrix ([1 .. k - 1] ++ [k + 1 .. DM.nrows matrix])

    pivotRhs = vector !! (k - 1)
    vector'  =
      [ if n == k then x else x - fromRational (matrix DM.! (n, m)) * pivotRhs
      | (n, x) <- zip [1 ..] vector
      ]
-- | Swap the elements at 1-based positions @i@ and @j@, multiplying the
-- element that moves into position @i@ (originally at @j@) by @factor@.
--
-- When @i == j@ the element is simply scaled.  'elimRow' uses this on the
-- right-hand side to mirror the @scaleRow@ followed by @switchRows@ it
-- applies to the coefficient matrix.  Positions outside @1 .. length list@
-- cause an index error when an affected element is forced.
switchScaleElems :: Num a => Int -> Int -> a -> [a] -> [a]
switchScaleElems i j factor list = zipWith place [1 ..] list
  where
    atI = list !! (i - 1)
    atJ = list !! (j - 1)
    -- Check i first so that i == j yields the scaled element.
    place idx x
      | idx == i  = factor * atJ
      | idx == j  = atI
      | otherwise = x
-- | Solve a linear system with the HMatrix backend and report the result as
-- a 'Consistency'.
--
-- The actual solving is shared with 'solveSystemH_Either'.  That function
-- converts the system, reduces it to row echelon form ('rref'), keeps the
-- first 'rank' rows and converts back.  On success the reduced system is
-- wrapped in 'Ok'.  If the conversion reports inconsistent rows, only the
-- first reported index is turned into a 'Bad'.
--
-- NOTE(review): the reported index @n@ is used 0-based for the list
-- (@v !! n@) but 1-based for "Data.Matrix" (@DM.getRow n m@).  The 'Bad' row
-- field is @DM.nrows m@ rather than @n@.  Confirm which convention
-- 'convertToHMatrix' uses before relying on these values.
solveSystemH :: LinearSystem -> Consistency LinearSystem
solveSystemH system@(m, v) =
  case solveSystemH_Either system of
    Right sys'  -> Ok sys'
    Left (n:_)  -> Bad system (DM.nrows m) (v !! n, V.toList (DM.getRow n m))
    -- Previously an anonymous pattern-match failure; make it explicit.
    Left []     ->
      error "solveSystemH: convertToHMatrix reported an inconsistency without naming a row"
-- | Solve a linear system with the HMatrix backend.
--
-- If 'convertToHMatrix' rejects the system, its list of row indices is
-- passed through unchanged in 'Left'.  Otherwise the converted matrix is
-- brought to reduced row echelon form with 'rref', and the first 'rank'
-- rows are kept.  The result is converted back to a 'LinearSystem' using
-- the unit list produced by the forward conversion.
solveSystemH_Either :: LinearSystem -> Either [Int] LinearSystem
solveSystemH_Either system@(_, _) = fmap reduce (convertToHMatrix system)
  where
    reduce (hmat, units) =
      let echelon = rref hmat
      in convertFromHMatrix (takeRows (rank echelon) echelon, units)