{-# LANGUAGE TypeFamilies #-}
module Test.LowerUpper (testsVar) where

import qualified Test.Divide as Divide
import qualified Test.Generator as Gen
import qualified Test.Utility as Util
import Test.Generator ((<#\#>), (<#*#>))
import Test.Utility (Tagged, approx, approxMatrix, maybeConjugate)

import qualified Numeric.LAPACK.Linear.LowerUpper as LU
import qualified Numeric.LAPACK.Matrix.Permutation as PermMatrix
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Matrix.Square as Square
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Permutation as Perm
import Numeric.LAPACK.Matrix.Square (Square)
import Numeric.LAPACK.Matrix (ShapeInt, (#*#), (#*##))
import Numeric.LAPACK.Scalar (RealOf, selectReal)

import qualified Numeric.Netlib.Class as Class

import Control.Applicative (liftA2, (<$>))

import Data.Semigroup ((<>))

import qualified Test.QuickCheck as QC


-- | Round-trip property for tall matrices.
--
-- The matrix is factored with 'LU.fromMatrix' and then multiplied back
-- together with 'LU.toMatrix'. The result must agree with the input
-- according to 'approxMatrix' at tolerance @1e-5@.
-- 'testsVar' runs this on matrices drawn from 'Gen.fullRankTall'.
toFromTallMatrix ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Matrix.Tall ShapeInt ShapeInt a -> Bool
toFromTallMatrix m =
   approxMatrix 1e-5 m reconstructed
   where
      reconstructed = LU.toMatrix (LU.fromMatrix m)

{-
Strictly wide matrices are problematic,
because a full-rank wide matrix can have a leading column
consisting entirely of zeros.
Row interchanges alone cannot then produce a non-zero pivot;
to prevent this, the LU decomposition would need column pivoting.
For now we restrict to square matrices,
which, being invertible, never have an all-zero column.
-}
-- | Round-trip property for square matrices.
--
-- Factoring with 'LU.fromMatrix' and reassembling with 'LU.toMatrix'
-- must reproduce the input according to 'approxMatrix' at tolerance @1e-5@.
-- 'testsVar' feeds it matrices from 'Gen.invertible'.
toFromSquareMatrix ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Square ShapeInt a -> Bool
toFromSquareMatrix sq =
   approxMatrix 1e-5 sq . LU.toMatrix . LU.fromMatrix $ sq


-- | Property relating the combined inversion flag to stepwise application.
--
-- Two computations must agree:
--
-- * multiplying by the permutation of the factorisation in one step,
--   using the combined flag @inv0 <> inv1@ with 'LU.multiplyP';
-- * extracting P with @inv1@ and then applying it to the right-hand side
--   with @inv0@ via 'Perm.apply'.
--
-- The tolerance is @selectReal 1e-1 1e-5@. NOTE(review): this is presumably
-- the looser bound for single precision; confirm in "Numeric.LAPACK.Scalar".
multiplyPApply ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (LU.Inversion, LU.Inversion) ->
   (Square ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
multiplyPApply (inv0,inv1) (m,rhs) =
   approxMatrix (selectReal 1e-1 1e-5) combined stepwise
   where
      lu = LU.fromMatrix m
      -- one call with the composed inversion flag
      combined = LU.multiplyP (inv0<>inv1) lu rhs
      -- P extracted with inv1, then applied with inv0
      perm = PermMatrix.toPermutation (LU.extractP inv1 lu)
      stepwise = Perm.apply inv0 perm rhs

-- | Property: 'LU.multiplyP' agrees with explicit multiplication.
--
-- The reference multiplies the permutation matrix from 'LU.extractP',
-- optionally inverted according to @inv@, with the right-hand side
-- using '#*##'.
multiplyP ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   LU.Inversion -> (Square ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
multiplyP inv (m,rhs) =
   approxMatrix (selectReal 1e-1 1e-5) viaLU viaExplicitP
   where
      lu = LU.fromMatrix m
      viaLU = LU.multiplyP inv lu rhs
      viaExplicitP = LU.extractP inv lu #*## rhs

-- | Property: 'LU.wideMultiplyL' on a square factorisation agrees with the
-- extracted lower factor.
--
-- The reference uses the lower factor obtained from 'LU.extractL' and
-- multiplies it via 'Matrix.multiplySquare'.
--
-- NOTE(review): this exercises the same LU function as the 'wideMultiplyL'
-- property. The two differ only in the reference factor
-- ('LU.extractL' versus 'LU.wideExtractL').
multiplyL ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   LU.Transposition ->
   (Square ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
multiplyL trans (m,rhs) =
   approxMatrix (selectReal 1e-1 1e-5) viaLU viaLowerFactor
   where
      lu = LU.fromMatrix m
      viaLU = LU.wideMultiplyL trans lu rhs
      viaLowerFactor = Matrix.multiplySquare trans (LU.extractL lu) rhs

-- | Property: 'LU.wideMultiplyL' agrees with multiplying by the lower factor
-- returned by 'LU.wideExtractL'.
--
-- The reference product is formed with 'Matrix.multiplySquare', respecting
-- the requested transposition. Inputs are square (see 'testsVar').
wideMultiplyL ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   LU.Transposition ->
   (Square ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
wideMultiplyL trans (m,rhs) =
   approxMatrix (selectReal 1e-1 1e-5) viaLU viaLowerFactor
   where
      lu = LU.fromMatrix m
      viaLU = LU.wideMultiplyL trans lu rhs
      viaLowerFactor = Matrix.multiplySquare trans (LU.wideExtractL lu) rhs

-- | Property: 'LU.tallMultiplyU' on a square factorisation agrees with the
-- extracted upper factor.
--
-- The reference uses the upper factor obtained from 'LU.extractU' and
-- multiplies it via 'Matrix.multiplySquare'.
--
-- NOTE(review): this exercises the same LU function as the 'tallMultiplyU'
-- property. The two differ only in the reference factor
-- ('LU.extractU' versus 'LU.tallExtractU') and in the input generator.
multiplyU ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   LU.Transposition ->
   (Square ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
multiplyU trans (m,rhs) =
   approxMatrix (selectReal 1e-1 1e-5) viaLU viaUpperFactor
   where
      lu = LU.fromMatrix m
      viaLU = LU.tallMultiplyU trans lu rhs
      viaUpperFactor = Matrix.multiplySquare trans (LU.extractU lu) rhs

-- | Property: for a tall matrix, 'LU.tallMultiplyU' agrees with multiplying
-- by the upper factor returned by 'LU.tallExtractU'.
--
-- The reference product is formed with 'Matrix.multiplySquare' under the
-- requested transposition.
tallMultiplyU ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   LU.Transposition ->
   (Matrix.Tall ShapeInt ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
tallMultiplyU trans (m,rhs) =
   approxMatrix (selectReal 1e-1 1e-5) viaLU viaUpperFactor
   where
      lu = LU.fromMatrix m
      viaLU = LU.tallMultiplyU trans lu rhs
      viaUpperFactor = Matrix.multiplySquare trans (LU.tallExtractU lu) rhs

-- | Property: the direct product @m #*## rhs@ of a square matrix and a
-- general matrix equals 'LU.multiplyFull' applied to the factorisation.
--
-- Before 'LU.multiplyFull' is applied, the factorisation's extent is
-- widened with 'Extent.fromSquare'.
multiplySquareFull ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (Square ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
multiplySquareFull (m,rhs) =
   approxMatrix (selectReal 1e-1 1e-5) direct viaLU
   where
      direct = m #*## rhs
      lu = LU.mapExtent Extent.fromSquare (LU.fromMatrix m)
      viaLU = LU.multiplyFull lu rhs

-- | Property: the direct product @m #*# rhs@ of a tall matrix and a general
-- matrix equals 'LU.multiplyFull' applied to the factorisation.
--
-- Before 'LU.multiplyFull' is applied, the factorisation's extent is
-- generalised with 'Extent.generalizeTall'.
multiplyTallFull ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (Matrix.Tall ShapeInt ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
multiplyTallFull (m,rhs) =
   approxMatrix (selectReal 1e-1 1e-5) direct viaLU
   where
      direct = m #*# rhs
      lu = LU.mapExtent Extent.generalizeTall (LU.fromMatrix m)
      viaLU = LU.multiplyFull lu rhs

-- | Property: 'Square.determinant' agrees with 'LU.determinant' of the LU
-- factorisation.
--
-- The comparison uses 'approx' with tolerance @selectReal 1e-1 1e-5@.
determinant ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Square ShapeInt a -> Bool
determinant m =
   approx (selectReal 1e-1 1e-5) direct viaLU
   where
      direct = Square.determinant m
      viaLU = LU.determinant (LU.fromMatrix m)

-- | Property: 'LU.wideSolveL' agrees with a generic triangular solve.
--
-- The reference is 'Matrix.solve' against the lower factor from
-- 'LU.wideExtractL', after 'maybeConjugate' has been applied according
-- to @conj@.
wideSolveL ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (LU.Transposition, LU.Conjugation) ->
   (Square ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
wideSolveL (trans,conj) (m,rhs) =
   approxMatrix (selectReal 1e-1 1e-5) viaLU viaLowerFactor
   where
      lu = LU.fromMatrix m
      lower = maybeConjugate conj (LU.wideExtractL lu)
      viaLU = LU.wideSolveL trans conj lu rhs
      viaLowerFactor = Matrix.solve trans lower rhs

-- | Property: 'LU.tallSolveU' agrees with a generic triangular solve.
--
-- The reference is 'Matrix.solve' against the upper factor from
-- 'LU.tallExtractU', after 'maybeConjugate' has been applied according
-- to @conj@.
tallSolveU ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (LU.Transposition, LU.Conjugation) ->
   (Matrix.Tall ShapeInt ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
tallSolveU (trans,conj) (m,rhs) =
   approxMatrix (selectReal 1e-1 1e-5) viaLU viaUpperFactor
   where
      lu = LU.fromMatrix m
      upper = maybeConjugate conj (LU.tallExtractU lu)
      viaLU = LU.tallSolveU trans conj lu rhs
      viaUpperFactor = Matrix.solve trans upper rhs

-- | Property: solving through the LU factorisation ('LU.solve') agrees with
-- 'Square.solve' applied directly to the square matrix.
solve ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (Square ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
solve (m,rhs) =
   approxMatrix (selectReal 1e-1 1e-5) direct viaLU
   where
      direct = Square.solve m rhs
      viaLU = LU.solve (LU.fromMatrix m) rhs


-- | Check a property against every value produced by a 'Gen.T' generator.
--
-- This delegates to 'Util.checkForAll' with the generator instantiated via
-- @Gen.run gen 3 5@. NOTE(review): the meaning of 3 and 5 (presumably
-- size bounds) is defined in "Test.Generator"; confirm there.
checkForAll ::
   (Show a, QC.Testable test) =>
   Gen.T dim tag a -> (a -> test) -> Tagged tag QC.Property
checkForAll gen prop =
   Util.checkForAll (Gen.run gen 3 5) prop

-- | Like 'checkForAll', but also draws a value from a plain QuickCheck
-- generator.
--
-- The extra value (in 'testsVar', flags produced by
-- 'QC.arbitraryBoundedEnum') is passed to the property as its first
-- argument. This delegates to 'Gen.withExtra'.
checkForAllExtra ::
   (Show a, Show b, QC.Testable test) =>
   QC.Gen a -> Gen.T dim tag b ->
   (a -> b -> test) -> Tagged tag QC.Property
checkForAllExtra extra gen prop =
   Gen.withExtra checkForAll extra gen prop


-- | Named QuickCheck properties for the LU module, for element type @a@.
--
-- The list ends with the shared 'Divide.testsVar' suite, instantiated with
-- LU factorisations of matrices from 'Gen.invertible'.
testsVar ::
   (Show a, Class.Floating a, Eq a, RealOf a ~ ar, Class.Real ar) =>
   [(String, Tagged a QC.Property)]
testsVar =
   [ ("toFromTallMatrix",
         checkForAll Gen.fullRankTall toFromTallMatrix)
   , ("toFromSquareMatrix",
         checkForAll Gen.invertible toFromSquareMatrix)
   , ("multiplyPApply",
         checkForAllExtra
            (liftA2 (,) QC.arbitraryBoundedEnum QC.arbitraryBoundedEnum)
            ((,) <$> Gen.invertible <#*#> Gen.matrix) multiplyPApply)
   , ("multiplyP",
         checkForAllExtra QC.arbitraryBoundedEnum
            ((,) <$> Gen.invertible <#*#> Gen.matrix) multiplyP)
   , ("multiplyL",
         checkForAllExtra QC.arbitraryBoundedEnum
            ((,) <$> Gen.invertible <#*#> Gen.matrix) multiplyL)
   , ("wideMultiplyL",
         checkForAllExtra QC.arbitraryBoundedEnum
            ((,) <$> Gen.invertible <#*#> Gen.matrix) wideMultiplyL)
   , ("multiplyU",
         checkForAllExtra QC.arbitraryBoundedEnum
            ((,) <$> Gen.invertible <#*#> Gen.matrix) multiplyU)
   , ("tallMultiplyU",
         checkForAllExtra QC.arbitraryBoundedEnum
            ((,) <$> Gen.fullRankTall <#*#> Gen.matrix) tallMultiplyU)
   , ("multiplySquareFull",
         checkForAll
            ((,) <$> Gen.invertible <#*#> Gen.matrix) multiplySquareFull)
   , ("multiplyTallFull",
         checkForAll
            ((,) <$> Gen.fullRankTall <#*#> Gen.matrix) multiplyTallFull)
   , ("determinant",
         checkForAll Gen.invertible determinant)
   , ("wideSolveL",
         checkForAllExtra
            (liftA2 (,) QC.arbitraryBoundedEnum QC.arbitraryBoundedEnum)
            ((,) <$> Gen.invertible <#\#> Gen.matrix) wideSolveL)
   , ("tallSolveU",
         checkForAllExtra
            (liftA2 (,) QC.arbitraryBoundedEnum QC.arbitraryBoundedEnum)
            ((,) <$> Gen.fullRankTall <#*#> Gen.matrix) tallSolveU)
   , ("solve",
         checkForAll ((,) <$> Gen.invertible <#\#> Gen.matrix) solve)
   ]
   ++ Divide.testsVar (LU.fromMatrix <$> Gen.invertible)