{-# LANGUAGE TypeFamilies #-}
module Test.Vector (testsVar) where

import qualified Test.Generator as Gen
import qualified Test.Utility as Util
import Test.Generator ((<+++>), (<.*#>), (<#*|>), (<|=|>))
import Test.Utility (Tagged, NonEmptyInt, EInt)

import qualified Numeric.LAPACK.Matrix.Triangular as Triangular
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Scalar as Scalar
import Numeric.LAPACK.Matrix (ShapeInt, shapeInt, (-/#))
import Numeric.LAPACK.Vector (Vector, (|+|), (|-|), (.*|))
import Numeric.LAPACK.Scalar (RealOf)

import qualified Numeric.Netlib.Class as Class

import Control.Applicative ((<$>))

import qualified Data.Array.Comfort.Storable as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable ((!))

import qualified Data.NonEmpty as NonEmpty
import Data.NonEmpty ((!:))

import qualified Test.QuickCheck as QC
import Test.ChasingBottoms.IsBottom (isBottom)


-- | A vector built with 'Vector.singleton' must equal a 'Vector.constant'
-- vector over the unit shape @()@ filled with the same value.
singleton :: (Class.Floating a) => a -> Bool
singleton x =
   let viaSingleton = Vector.singleton x
       viaConstant = Vector.constant () x
   in Util.equalVector viaSingleton viaConstant


-- | Generate a vector (over a 'NonEmptyInt' shape) together with a pair of
-- indices, each chosen with 'QC.elements' from the vector's index set.
genSwapVector ::
   (Class.Floating a) =>
   Gen.Vector NonEmptyInt a ((EInt, EInt), Vector NonEmptyInt a)
genSwapVector = Gen.mapQC pick Gen.vector
   where
      -- Draw two (possibly equal) indices of the given vector.
      pick vec = do
         let ixs = Shape.indices (Array.shape vec)
         first <- QC.elements ixs
         second <- QC.elements ixs
         return ((first, second), vec)

-- | Swapping the same two elements twice must restore the original vector.
swapInverse ::
   (Eq sh, Shape.Indexed sh, Shape.Index sh ~ ix,
    Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   ((ix,ix), Vector sh a) -> Bool
swapInverse ((i,j),x) =
   Util.equalVector x ((swapIJ . swapIJ) x)
   where
      swapIJ = Vector.swap i j

-- | The order of the two indices passed to 'Vector.swap' must not matter.
swapCommutative ::
   (Eq sh, Shape.Indexed sh, Shape.Index sh ~ ix,
    Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   ((ix,ix), Vector sh a) -> Bool
swapCommutative ((i,j),x) =
   let swappedIJ = Vector.swap i j x
       swappedJI = Vector.swap j i x
   in Util.equalVector swappedIJ swappedJI


-- | 'Vector.norm2Squared' must approximately equal the square of
-- 'Vector.norm2'.  The tolerance is chosen by 'Scalar.selectReal'
-- (1e-3 or 1e-10, depending on the real type).
norm2Squared ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Vector ShapeInt a -> Bool
norm2Squared x =
   Util.approxReal tolerance direct squaredNorm
   where
      tolerance = Scalar.selectReal 1e-3 1e-10
      direct = Vector.norm2Squared x
      squaredNorm = Vector.norm2 x ^ (2::Int)

-- | The inner product of a vector with itself must equal its squared
-- 2-norm, lifted back into the scalar type via 'Scalar.fromReal'.
norm2Inner ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Vector ShapeInt a -> Bool
norm2Inner x =
   Scalar.equal selfInner squaredNormScalar
   where
      selfInner = Vector.inner x x
      squaredNormScalar = Scalar.fromReal $ Vector.norm2Squared x


-- | 'Vector.normInf' must equal the largest 'Scalar.absolute' value of the
-- elements.  The value is 0 for an empty vector because the reference
-- maximum is seeded with 0.
normInf ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Vector ShapeInt a -> Bool
normInf x =
   Vector.normInf x == reference
   where
      reference =
         NonEmpty.maximum $ 0 !: map Scalar.absolute (Array.toList x)

-- | 'Vector.normInf1' must equal the largest 'Scalar.norm1' value of the
-- elements.  The value is 0 for an empty vector because the reference
-- maximum is seeded with 0.
normInf1 ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Vector ShapeInt a -> Bool
normInf1 x =
   Vector.normInf1 x == reference
   where
      reference =
         NonEmpty.maximum $ 0 !: map Scalar.norm1 (Array.toList x)


-- | The infinity norm of a concatenation must equal the infinity norm of
-- the vector holding the two parts' infinity norms.
normInfAppend ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar, RealOf ar ~ ar) =>
   (Vector ShapeInt a, Vector ShapeInt a) -> Bool
normInfAppend (x,y) =
   Vector.normInf joined == Vector.normInf partNorms
   where
      joined = Vector.append x y
      partNorms = Vector.autoFromList (map Vector.normInf [x, y])

-- | The 'Vector.normInf1' of a concatenation must equal the 'Vector.normInf1'
-- of the vector holding the two parts' 'Vector.normInf1' values.
normInf1Append ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar, RealOf ar ~ ar) =>
   (Vector ShapeInt a, Vector ShapeInt a) -> Bool
normInf1Append (x,y) =
   Vector.normInf1 joined == Vector.normInf1 partNorms
   where
      joined = Vector.append x y
      partNorms = Vector.autoFromList (map Vector.normInf1 [x, y])


-- | 'Vector.sum' must agree exactly with 'sum' over the element list.
sumList :: (Eq a, Class.Floating a) => Vector ShapeInt a -> Bool
sumList xs = Vector.sum xs == listTotal
   where
      listTotal = sum (Vector.toList xs)

-- | 'Vector.product' must agree exactly with 'product' over the element list.
productList :: (Eq a, Class.Floating a) => Vector ShapeInt a -> Bool
productList xs = Vector.product xs == listProduct
   where
      listProduct = product (Vector.toList xs)


-- | Check a property of a reduction @f@ that is only defined for non-empty
-- vectors.  For an empty vector (shape @shapeInt 0@), @f@ must yield
-- bottom.  Otherwise @law@ is checked against @f@'s result and the input.
withNonEmpty ::
   (Vector ShapeInt a -> b) ->
   (b -> Vector ShapeInt a -> Bool) ->
   Vector ShapeInt a -> Bool
withNonEmpty f law xs
   | Array.shape xs == shapeInt 0 = isBottom result
   | otherwise = law result xs
   where
      result = f xs

-- | 'Vector.minimum' must match 'minimum' over the element list.  It must
-- be bottom on an empty vector.
minimumList :: (Class.Real a) => Vector ShapeInt a -> Bool
minimumList xs =
   withNonEmpty Vector.minimum
      (\smallest vec -> smallest == minimum (Vector.toList vec))
      xs

-- | 'Vector.maximum' must match 'maximum' over the element list.  It must
-- be bottom on an empty vector.
maximumList :: (Class.Real a) => Vector ShapeInt a -> Bool
maximumList xs =
   withNonEmpty Vector.maximum
      (\largest vec -> largest == maximum (Vector.toList vec))
      xs

-- | 'Vector.limits' must equal the pair ('Vector.minimum', 'Vector.maximum')
-- on non-empty vectors.
limitsMinimumMaximum :: (Class.Real a) => Vector ShapeInt a -> Bool
limitsMinimumMaximum xs =
   withNonEmpty Vector.limits
      (\bounds vec -> bounds == (Vector.minimum vec, Vector.maximum vec))
      xs

-- | 'Vector.limits' must agree with the generic 'Array.limits' on non-empty
-- vectors.
limits :: (Class.Real a) => Vector ShapeInt a -> Bool
limits xs =
   withNonEmpty Vector.limits
      (\bounds vec -> bounds == Array.limits vec)
      xs

-- | 'Vector.argLimits' must equal the pair
-- ('Vector.argMinimum', 'Vector.argMaximum') on non-empty vectors.
argLimits :: (Class.Real a) => Vector ShapeInt a -> Bool
argLimits xs =
   withNonEmpty Vector.argLimits
      (\args vec -> args == (Vector.argMinimum vec, Vector.argMaximum vec))
      xs


-- | 'Vector.argAbsMaximum' returns an index and element.  The element must
-- be the one stored at that index, and its 'Scalar.absolute' value must
-- equal the vector's 'Vector.normInf'.
argAbsMaximum ::
   (Eq a, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Vector ShapeInt a -> Bool
argAbsMaximum xs =
   withNonEmpty Vector.argAbsMaximum
      (\(ix,v) vec ->
         vec!ix == v && Scalar.absolute v == Vector.normInf vec)
      xs

-- | 'Vector.argAbs1Maximum' returns an index and element.  The element must
-- be the one stored at that index, and its 'Scalar.norm1' value must equal
-- the vector's 'Vector.normInf1'.
argAbs1Maximum ::
   (Eq a, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Vector ShapeInt a -> Bool
argAbs1Maximum xs =
   withNonEmpty Vector.argAbs1Maximum
      (\(ix,v) vec ->
         vec!ix == v && Scalar.norm1 v == Vector.normInf1 vec)
      xs


-- | Raising every element by zero must leave the vector unchanged.
raiseZero :: (Eq a, Class.Floating a) => Vector ShapeInt a -> Bool
raiseZero xs = Util.equalVector xs raised
   where
      raised = Vector.raise Scalar.zero xs

-- | @Vector.raise x ys@ must equal adding the constant vector @x@ (with the
-- shape of @ys@) to @ys@.
addRaise :: (Eq a, Class.Floating a) => (a, Vector ShapeInt a) -> Bool
addRaise (x,ys) = Util.equalVector raised shifted
   where
      raised = Vector.raise x ys
      shifted = ys |+| Vector.constant (Array.shape ys) x

-- | Raising by @negate x@ must equal subtracting the constant vector @x@
-- (with the shape of @ys@) from @ys@.
subRaise :: (Eq a, Class.Floating a) => (a, Vector ShapeInt a) -> Bool
subRaise (x,ys) = Util.equalVector lowered shifted
   where
      lowered = Vector.raise (negate x) ys
      shifted = ys |-| Vector.constant (Array.shape ys) x

-- | 'Vector.mac' (multiply-accumulate) must agree with scaling @xs@ by @a@
-- and then adding @ys@ using the separate vector operators.
addScaleMac ::
   (Eq a, Class.Floating a) => (a, Vector ShapeInt a, Vector ShapeInt a) -> Bool
addScaleMac (a,xs,ys) =
   let viaMac = Vector.mac a xs ys
       viaOperators = a .*| xs |+| ys
   in Util.equalVector viaMac viaOperators


-- | Solving with a diagonal matrix (@-/#@) must approximately equal
-- elementwise division by that matrix's diagonal.
divide ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (Triangular.Diagonal ShapeInt a, Vector ShapeInt a) -> Bool
divide (a,b) =
   Util.approxVector solved elementwise
   where
      solved = b -/# a
      elementwise = Vector.divide b (Triangular.takeDiagonal a)


-- | Run a generator through 'Gen.run' with fixed parameters and check the
-- given property against every generated value.
--
-- NOTE(review): the arguments 10 and 5 are passed straight to 'Gen.run';
-- their meaning (presumably size limits) is defined in "Test.Generator".
-- TODO confirm.
checkForAll ::
   (Show a, QC.Testable test) =>
   Gen.T dim tag a -> (a -> test) -> Tagged tag QC.Property
checkForAll gen law = Util.checkForAll (Gen.run gen 10 5) law


-- | Named QuickCheck properties for the vector API, parameterised by the
-- element type @a@.  The extremum tests (minimum, maximum, limits,
-- argLimits) run on the real part of the generated vector.
testsVar ::
   (Show a,
    Class.Floating a, Eq a, RealOf a ~ ar, Class.Real ar, RealOf ar ~ ar) =>
   [(String, Tagged a QC.Property)]
testsVar =
   [ ("singleton",
         checkForAll Gen.scalar singleton)
   , ("swapInverse",
         checkForAll genSwapVector swapInverse)
   , ("swapCommutative",
         checkForAll genSwapVector swapCommutative)
   , ("norm2Squared",
         checkForAll Gen.vector norm2Squared)
   , ("norm2Inner",
         checkForAll Gen.vector norm2Inner)
   , ("normInf",
         checkForAll Gen.vector normInf)
   , ("normInf1",
         checkForAll Gen.vector normInf1)
   , ("normInfAppend",
         checkForAll ((,) <$> Gen.vector <+++> Gen.vector) normInfAppend)
   , ("normInf1Append",
         checkForAll ((,) <$> Gen.vector <+++> Gen.vector) normInf1Append)
   , ("sum",
         checkForAll Gen.vector sumList)
   , ("product",
         checkForAll Gen.vector productList)
   , ("minimum",
         checkForAll Gen.vector (minimumList . Vector.realPart))
   , ("maximum",
         checkForAll Gen.vector (maximumList . Vector.realPart))
   , ("limitsMinimumMaximum",
         checkForAll Gen.vector (limitsMinimumMaximum . Vector.realPart))
   , ("limits",
         checkForAll Gen.vector (limits . Vector.realPart))
   , ("argLimits",
         checkForAll Gen.vector (argLimits . Vector.realPart))
   , ("argAbsMaximum",
         checkForAll Gen.vector argAbsMaximum)
   , ("argAbs1Maximum",
         checkForAll Gen.vector argAbs1Maximum)
   , ("raiseZero",
         checkForAll Gen.vector raiseZero)
   , ("addRaise",
         checkForAll ((,) <$> Gen.scalar <.*#> Gen.vector) addRaise)
   , ("subRaise",
         checkForAll ((,) <$> Gen.scalar <.*#> Gen.vector) subRaise)
   , ("addScaleMac",
         checkForAll
            ((,,) <$> Gen.scalar <.*#> Gen.vector <|=|> Gen.vector)
            addScaleMac)
   , ("divide",
         checkForAll
            ((,) <$> Gen.condition Util.invertible Gen.diagonal <#*|> Gen.vector)
            divide)
   ]