{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE EmptyDataDecls #-}
module Test.Generator where

import qualified Test.Logic as Logic
import qualified Test.Utility as Util
import Test.Logic
         (Dim, MatchMode(DontForceMatch,ForceMatch), (=!=), (<!=), (!+!), (!*!))
import Test.Utility (Match)

import qualified UniqueLogic.ST.TF.System.Simple as Sys

import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Triangular as Triangular
import qualified Numeric.LAPACK.Matrix.Square as Square
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Hermitian (Hermitian)
import Numeric.LAPACK.Matrix (ShapeInt)
import Numeric.LAPACK.Scalar (RealOf, fromReal, one)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape ((:+:))

import qualified Control.Monad.Trans.RWS as MRWS
import qualified Control.Monad.Trans.Writer as MW
import qualified Control.Monad.Trans.Reader as MR
import qualified Control.Monad.Trans.Class as MT
import qualified Control.Applicative.HT as AppHT
import qualified Control.Functor.HT as FuncHT
import Control.Applicative (liftA2, liftA3, (<*>), (<$>))

import qualified Data.NonEmpty.Class as NonEmptyC
import qualified Data.Ref as Ref
import Data.Semigroup ((<>))
import Data.Monoid (Monoid, mempty)
import Data.Tuple.HT (swap, mapFst, mapSnd)

import qualified Test.QuickCheck as QC



{- |
@generator :: Base dim array@
constructs an array shape or an array of type @array@
with dimensions @dim@ and maintains relations between the dimensions.
Dimensions will be chosen arbitrarily from the range @(0,maxDim)@.
@TaggedVariables s dim@ maintains the 'Logic' variables for the array dimensions.
@Logic.M s array@ contains the query for retrieving the object
that depends on the solved logic system.
@M s@ collects all constraints in a 'MW.WriterT'.

I moved the 's' tag to within the 'Base' constructor
and furthermore defined 'TaggedVariables'
so that 'dim' itself carries no 's' tag;
@TaggedVariables s dim@ re-attaches the tag to every 'Variable' in 'dim'.
This way, we can easily define 'checkForAll' in the test modules.
Otherwise there would not be a way to quantify 'dim' while containing 's' tags.
That is, we would have to reset 'dim' to () before every call to 'checkForAll'.
-}
newtype Base dim array = Base {unbase :: forall s. BaseM s dim array}

{- |
'T' adds the capability of array creation to 'Base'.
To this end it employs a 'MR.ReaderT' that provides the 'maxElem' parameter.
Array elements are chosen from the range @(-maxElem,maxElem)@.
The many levels of construction of 'T' look complicated,
but every level represents a major step.
While 'Base' generates matching dimensions for all involved objects,
'T' adds the final 'QC.Gen' that generates
all arrays containing QuickCheck generated random values.
The separation of dimension and array creation allows us
to attach constraints such as invertibility to the generator afterwards,
or to extend single generators to list generators.
-}
newtype T dim elem array = Cons (forall s. BaseM s dim (MaxElemQCGen array))
-- | Final array generator.
-- The 'MR.ReaderT' environment is the @maxElem@ bound that 'run' supplies.
type MaxElemQCGen = MR.ReaderT Integer QC.Gen

-- | One generator step at tag @s@.
-- It allocates the dimension variables (returned as 'TaggedVariables')
-- together with a query that retrieves the generated object
-- after the logic system has been solved (see 'run').
type BaseM s dim array = M s (TaggedVariables s dim, Logic.M s array)

-- | Generator monad: a 'MW.WriterT' that accumulates
-- 'Logic.System' constraints on top of 'Logic.M'.
type M s = MW.WriterT (Logic.System s) (Logic.M s)
-- | Empty marker type: a @Variable dim@ leaf in a dimension description
-- stands for one logic variable holding a value of type @dim@.
data Variable dim

-- | Attach the @s@ tag to an untagged dimension description.
-- Each 'Variable' leaf becomes a 'Logic.Variable', @()@ stays @()@,
-- and pairs are mapped componentwise.
type family TaggedVariables s tuple
type instance TaggedVariables s (Variable dim) = Logic.Variable s dim
type instance TaggedVariables s () = ()
type instance TaggedVariables s (a,b) =
                  (TaggedVariables s a, TaggedVariables s b)

-- | Maps over the deferred query only.
-- Dimension variables and collected constraints pass through unchanged.
instance Functor (Base dim) where
   fmap f (Base gen) = Base (fmap (mapSnd (fmap f)) gen)

-- | Maps over the generated array.
-- The function is pushed through the deferred query
-- and then through the final 'MaxElemQCGen'.
instance Functor (T dim elem) where
   fmap f (Cons gen) = Cons (fmap (mapSnd (fmap (fmap f))) gen)

-- | Allocate a fresh logic variable.
-- No constraint is emitted: the writer output is 'mempty'.
newVariable :: (Ref.C m, Monoid w) => MW.WriterT w m (Sys.Variable m a)
newVariable =
   MW.WriterT $ do
      var <- Sys.globalVariable
      return (var, mempty)

-- | Allocate a fresh logic variable.
-- The constraint obtained by applying @constraint@ to it
-- is recorded immediately.
newVariableWith ::
   (Ref.C m, Monoid w) =>
   (Sys.Variable m a -> w) -> MW.WriterT w m (Sys.Variable m a)
newVariableWith constraint =
   newVariable >>= \var -> var <$ MW.tell (constraint var)

-- | Turn a generator into a tagged QuickCheck generator.
--
-- The steps are:
--
-- 1. pick a 'MatchMode' at random;
-- 2. collect and solve the dimension constraints with the @maxDim@ bound;
-- 3. run the deferred queries;
-- 4. generate the arrays with elements bounded by @maxElem@.
run :: T dim elem array -> Integer -> Int -> Util.TaggedGen elem (array, Match)
run (Cons gen) maxElem maxDim =
   Util.Tagged $ do
      matchMode <- QC.elements [DontForceMatch, ForceMatch]
      ~(arrayGen, match) <-
         Logic.runSTInGen
            (do ((_dim, queries), sys) <- MW.runWriterT gen
                Logic.solve sys
                queries)
            maxDim matchMode
      array <- MR.runReaderT arrayGen maxElem
      return (array, match)

-- | Run a property checker with one more generated argument.
-- The extra argument is drawn from @genA@, independently of the
-- arrays produced by @genB@, and is passed to @test@ first.
withExtra ::
   (T dim elem (a,b) -> ((a,b) -> c) -> io) ->
   QC.Gen a -> T dim elem b -> (a -> b -> c) -> io
withExtra checkForAll genA genB test =
   checkForAll (mapQC attachExtra genB) (uncurry test)
   where
      attachExtra b = fmap (\a -> (a, b)) genA


-- | View a 'Base' whose payload is already a 'MaxElemQCGen' as a 'T'.
fromBase :: Base dim (MaxElemQCGen a) -> T dim elem a
fromBase base =
   case base of
      Base gen -> Cons gen

-- | Lift a transformation of 'Base' generators to 'T' generators.
-- The 'T' is rewrapped as a 'Base', transformed, and unwrapped again.
liftBase ::
   (Base dimA (MaxElemQCGen a) -> Base dimB (MaxElemQCGen b)) ->
   T dimA elem a -> T dimB elem b
liftBase f (Cons gen) =
   case f (Base gen) of
      Base transformed -> Cons transformed

-- | Keep only generated objects that satisfy @predicate@.
-- The final generator is wrapped in 'QC.suchThat', which retries until
-- the predicate holds, so a rarely-true predicate makes generation slow.
condition :: (a -> Bool) -> T dim elem a -> T dim elem a
condition predicate =
   liftBase (fmap (MR.mapReaderT (\gen -> gen `QC.suchThat` predicate)))

-- | Post-process each generated object with an additional random step.
mapQC :: (a -> QC.Gen b) -> T dim elem a -> T dim elem b
mapQC f = liftBase (fmap (\readerGen -> readerGen >>= MT.lift . f))

-- | Turn the queried object (usually a shape) into a random generator.
-- @f@ receives the @maxElem@ bound as its first argument.
mapGen :: (Integer -> a -> QC.Gen b) -> Base dim a -> T dim elem b
mapGen f base = fromBase (fmap toReader base)
   where
      toReader a = MR.ReaderT (\maxElem -> f maxElem a)

-- | Like 'mapGen', but @f@ also receives the @maxDim@ bound.
-- That bound is read from the 'Logic.M' environment.
mapGenDim :: (Integer -> Int -> a -> QC.Gen b) -> Base dim a -> T dim elem b
mapGenDim f (Base gen) =
   Cons $ do
      (maxDim, _matchMode) <- MT.lift $ Logic.M MRWS.ask
      let toReader a = MR.ReaderT (\maxElem -> f maxElem maxDim a)
      fmap (mapSnd (fmap toReader)) gen


-- | Record an extra constraint on the dimension variables of a generator.
-- The generated variables and query are returned unchanged.
constrain ::
   (forall s. TaggedVariables s dim -> Logic.System s) ->
   Base dim a -> Base dim a
constrain constraint (Base gen) =
   Base $
      gen >>= \(dim, a) ->
      MW.tell (constraint dim) >> return (dim, a)

-- | Apply generated functions to generated arguments.
-- @combineDim@ derives the result's dimension variables from both
-- operands and returns a constraint, which is recorded.
-- This is a variant of 'combineM' for purely computed dimensions.
combine ::
   (forall s.
    TaggedVariables s dimF -> TaggedVariables s dimA ->
    (TaggedVariables s dimB, Logic.System s)) ->
   T dimF elem (a -> b) ->
   T dimA elem a ->
   T dimB elem b
combine combineDim =
   combineM
      (\dimF dimA ->
         let (dimB, constraint) = combineDim dimF dimA
         in  dimB <$ MW.tell constraint)

-- | Apply generated functions to generated arguments.
-- Both operand generators run first; then @combineDim@ may allocate
-- new variables and emit constraints to produce the result dimensions.
-- Queries and final generators are combined applicatively.
combineM ::
   (forall s.
    TaggedVariables s dimF -> TaggedVariables s dimA ->
    M s (TaggedVariables s dimB)) ->
   T dimF elem (a -> b) ->
   T dimA elem a ->
   T dimB elem b
combineM combineDim (Cons genF) (Cons genA) =
   Cons $
      genF >>= \(dimF, f) ->
      genA >>= \(dimA, a) ->
      (\dimB -> (dimB, liftA2 (<*>) f a)) <$> combineDim dimF dimA

-- | Three-operand version of 'combineM'.
-- The operand generators run in order F, A, B;
-- then @combineDim@ produces the result dimensions.
combine3M ::
   (forall s.
    TaggedVariables s dimF ->
    TaggedVariables s dimA -> TaggedVariables s dimB ->
    M s (TaggedVariables s dimC)) ->
   T dimF elem (a -> b -> c) ->
   T dimA elem a ->
   T dimB elem b ->
   T dimC elem c
combine3M combineDim (Cons genF) (Cons genA) (Cons genB) =
   Cons $ do
      (dimsF, queryF) <- genF
      (dimsA, queryA) <- genA
      (dimsB, queryB) <- genB
      dimsC <- combineDim dimsF dimsA dimsB
      return (dimsC, liftA3 apply2 queryF queryA queryB)
   where
      -- apply a generated binary function to two generated arguments
      apply2 fi ai bi = fi <*> ai <*> bi


-- | 'Base' generator of an object without dimension variables.
type ScalarBase = Base ()
-- | 'T' generator of an object without dimension variables.
type Scalar elem = T () elem

-- | Random scalar built from 'Util.genElement'.
-- It receives the @maxElem@ bound, allocates no dimension variables
-- and emits no constraints.
scalarGen :: (Class.Floating a) => Scalar b a
scalarGen = Cons (pure ((), pure (MR.ReaderT Util.genElement)))

-- | 'scalarGen' where the generated type coincides with the tag @elem@.
scalar :: (Class.Floating a) => Scalar a a
scalar = scalarGen

-- | 'scalarGen' producing a value of the real type 'RealOf' of the tag.
scalarReal :: (Class.Floating a, RealOf a ~ ar, Class.Real ar) => Scalar a ar
scalarReal = scalarGen

-- | Combine two vectors into a scalar result.
-- Both vector sizes are required to be equal.
(<-*|>) ::
   (Dim size, Eq size) =>
   Vector size elem (a -> b) ->
   Vector size elem a ->
   Scalar elem b
vecF <-*|> vecA =
   combine (\sizeF sizeA -> ((), sizeF =!= sizeA)) vecF vecA


-- | Generators whose shape is described by a single dimension variable.
type VectorBase size = Base (Variable size)
type Vector size elem = T (Variable size) elem
-- | 'Vector' specialised to 'ShapeInt' sizes.
type VectorInt elem = Vector ShapeInt elem

-- | Allocate a fresh size variable.
-- Its solved value is the queried object.
vectorDim :: (Dim size) => VectorBase size size
vectorDim =
   Base (fmap (\size -> (size, Logic.query size)) newVariable)

-- | Random vector of a fresh size, built with 'Util.genVector'.
-- The element type @a@ is independent of the tag @b@.
vectorGen ::
   (Dim size, Class.Floating a) => Vector size b (Vector.Vector size a)
vectorGen = mapGen Util.genVector vectorDim

-- | 'vectorGen' where the element type coincides with the tag.
vector ::
   (Dim size, Class.Floating a) => Vector size a (Vector.Vector size a)
vector = vectorGen

-- | 'vectorGen' with elements of the real type 'RealOf' of the tag,
-- specialised to 'ShapeInt' sizes.
vectorReal ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   VectorInt a (Vector.Vector ShapeInt ar)
vectorReal = vectorGen

-- | Vector-on-the-left times matrix.
-- The vector size must equal the matrix height;
-- the result takes the matrix width.
(<-*#>) ::
   (Dim height, Eq height) =>
   Vector height elem (a -> b) ->
   Matrix height width elem a ->
   Vector width elem b
vec <-*#> mat =
   combine (\sizeV (height, width) -> (width, sizeV =!= height)) vec mat

-- | Matrix times vector-on-the-right.
-- The matrix width must equal the vector size;
-- the result takes the matrix height.
(<#*|>) ::
   (Dim width, Eq width) =>
   Matrix height width elem (a -> b) ->
   Vector width elem a ->
   Vector height elem b
mat <#*|> vec =
   combine (\(height, width) sizeV -> (height, width =!= sizeV)) mat vec

-- | Combine two vectors of equal size into a vector of that size.
(<|=|>) ::
   (Dim size, Eq size) =>
   Vector size elem (a -> b) ->
   Vector size elem a ->
   Vector size elem b
vecF <|=|> vecA =
   combine (\sizeF sizeA -> (sizeF, sizeF =!= sizeA)) vecF vecA

-- | Concatenate two vectors.
-- The result size is a fresh variable constrained with '!+!' to the
-- two operand sizes.
(<+++>) ::
   Vector sizeA elem (a -> b) ->
   Vector sizeB elem a ->
   Vector (sizeA:+:sizeB) elem b
vecA <+++> vecB =
   combineM (\sizeA sizeB -> newVariableWith (sizeA !+! sizeB)) vecA vecB


-- | Generators whose shape is described by a (height, width) pair
-- of dimension variables.
type MatrixBase height width = Base (Variable height, Variable width)
type Matrix height width = T (Variable height, Variable width)
-- | 'Matrix' specialised to 'ShapeInt' in both dimensions.
type MatrixInt = Matrix ShapeInt ShapeInt

-- | Build a shape from queried dimensions.
-- A storage order is drawn with 'Util.genOrder' inside the query.
shapeFromDims :: (MatrixShape.Order -> a -> b) -> Base dim a -> Base dim b
shapeFromDims f (Base gen) =
   Base (fmap (mapSnd (liftA2 f (Logic.liftGen Util.genOrder))) gen)

-- | Allocate independent height and width variables, height first.
-- Their solved values form the queried pair.
matrixDims ::
   (Dim height, Dim width) => MatrixBase height width (height, width)
matrixDims =
   Base $ do
      height <- newVariable
      width <- newVariable
      return
         ((height, width),
          liftA2 (,) (Logic.query height) (Logic.query width))

-- | General matrix shape with independent height and width.
matrixShape ::
   (Dim height, Dim width) =>
   MatrixBase height width (MatrixShape.General height width)
matrixShape =
   shapeFromDims
      (\order ~(height, width) -> MatrixShape.general order height width)
      matrixDims

-- | Random general matrix filled by 'Util.genArray'.
matrix ::
   (Dim height, Dim width, Class.Floating a) =>
   Matrix height width a (Matrix.General height width a)
matrix = mapGen Util.genArray matrixShape

-- | 'matrix' specialised to 'ShapeInt' dimensions.
matrixInt ::
   (Class.Floating a) => MatrixInt a (Matrix.General ShapeInt ShapeInt a)
matrixInt = matrix


-- | Turn an object generator into a generator of an @f@-container of objects.
-- All objects share the same dimension variables.
-- The container is produced by 'NonEmptyC.genOf' applied to the final
-- generator (presumably non-empty, per the class name), and is paired
-- with the size obtained from @querySize@.
listOf ::
   (NonEmptyC.Gen f) =>
   (forall s. TaggedVariables s dim -> Logic.M s size) ->
   T dim elem a -> T dim elem (size, f a)
listOf querySize (Cons gen) =
   Cons $ do
      (dim, query) <- gen
      let listQuery =
             liftA2
                (\size qc -> (,) size <$> MR.mapReaderT NonEmptyC.genOf qc)
                (querySize dim) query
      return (dim, listQuery)

-- | 'listOf' for vectors; the reported size is the solved vector size.
listOfVector ::
   (Dim size, NonEmptyC.Gen f) =>
   Vector size elem a -> Vector size elem (size, f a)
listOfVector = listOf Logic.query

-- | 'listOf' for matrices; the reported size is the solved (height, width).
listOfMatrix ::
   (Dim height, Dim width, NonEmptyC.Gen f) =>
   Matrix height width elem a -> Matrix height width elem ((height,width), f a)
listOfMatrix = listOf (AppHT.mapPair (Logic.query,Logic.query))


-- | Matrix generators whose height and width have the same type @sh@.
type SquareBase sh = MatrixBase sh sh
type Square sh = Matrix sh sh

-- | Allocate one size variable.
-- It is used as both height and width, and its solved value is queried.
squareDim :: (Dim sh) => SquareBase sh sh
squareDim =
   Base (fmap (\size -> ((size, size), Logic.query size)) newVariable)

-- | Square matrix shape with a random storage order.
squareShape :: (Dim sh) => SquareBase sh (MatrixShape.Square sh)
squareShape = shapeFromDims MatrixShape.square squareDim

-- | Random square matrix filled by 'Util.genArray'.
square :: (Class.Floating a) => MatrixInt a (Square.Square ShapeInt a)
square = mapGen Util.genArray squareShape

-- | Random square matrix kept only if 'Util.invertible' accepts it.
invertible ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   MatrixInt a (Square.Square ShapeInt a)
invertible = condition Util.invertible square

-- | Random diagonal matrix of a fresh square size.
diagonal :: (Class.Floating a) => MatrixInt a (Triangular.Diagonal ShapeInt a)
diagonal = mapGen Util.genArray $ shapeFromDims MatrixShape.diagonal squareDim

-- | Identity matrix of a fresh square size.
-- Only the shape (including its storage order) is random;
-- no array elements are drawn.
identity ::
   (MatrixShape.Content lo, MatrixShape.Content up, Class.Floating a) =>
   MatrixInt a (Triangular.Triangular lo MatrixShape.Unit up ShapeInt a)
identity =
   fromBase (pure <$> shapeFromDims Triangular.identity squareDim)

-- | Triangular shape of a fresh square size.
-- The diagonal kind and the upper/lower content come from the result type
-- via 'MatrixShape.autoDiag' and 'MatrixShape.autoUplo'.
triangularShape ::
   (MatrixShape.Content up, MatrixShape.Content lo, MatrixShape.TriDiag diag,
    Dim sh) =>
   SquareBase sh (MatrixShape.Triangular lo diag up sh)
triangularShape =
   shapeFromDims
      (MatrixShape.Triangular MatrixShape.autoDiag MatrixShape.autoUplo)
      squareDim

-- | Random triangular matrix filled by 'genTriangularArray'.
triangular ::
   (MatrixShape.Content up, MatrixShape.Content lo, MatrixShape.TriDiag diag,
    Dim sh, Shape.Indexed sh, Shape.Index sh ~ ix, Eq ix, Class.Floating a) =>
   Square sh a (Triangular.Triangular lo diag up sh a)
triangular = mapGen genTriangularArray triangularShape


-- | Wrapper that puts @diag@ in the last type position,
-- so that 'MatrixShape.switchTriDiag' can select a generator by @diag@.
newtype GenTriangularDiag lo up sh a diag =
   GenTriangularDiag {
      runGenTriangularDiag ::
         MatrixShape.Triangular lo diag up sh ->
         QC.Gen (Triangular.Triangular lo diag up sh a)
   }

-- | Fill a triangular shape with random elements bounded by @maxElem@.
-- 'MatrixShape.switchTriDiag' chooses between two cases:
--
-- * first case: diagonal entries fixed to 'one' via
--   'Util.genArrayExtraDiag' (presumably the unit-diagonal case — confirm
--   against 'MatrixShape.switchTriDiag');
-- * second case: all entries from 'Util.genArray'.
genTriangularArray ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Shape.Indexed sh, Shape.Index sh ~ ix, Eq ix, Class.Floating a) =>
   Integer ->
   MatrixShape.Triangular lo diag up sh ->
   QC.Gen (Triangular.Triangular lo diag up sh a)
genTriangularArray maxElem =
   runGenTriangularDiag $
   MatrixShape.switchTriDiag
      (GenTriangularDiag $ \shape ->
         Util.genArrayExtraDiag maxElem shape (const $ return one))
      (GenTriangularDiag $ Util.genArray maxElem)


-- | Tall matrix shape.
-- The dimension constraint @width <!= height@ is recorded
-- (presumably width <= height).
tallShape :: MatrixBase ShapeInt ShapeInt (MatrixShape.Tall ShapeInt ShapeInt)
tallShape =
   shapeFromDims
      (\order ~(height, width) -> MatrixShape.tall order height width)
      (constrain (\ ~(height, width) -> width <!= height) matrixDims)

-- | Random tall matrix filled by 'Util.genArray'.
tall :: (Class.Floating a) => MatrixInt a (Matrix.Tall ShapeInt ShapeInt a)
tall = mapGen Util.genArray tallShape

-- | Random tall matrix kept only if 'Util.fullRankTall' accepts it.
fullRankTall ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   MatrixInt a (Matrix.Tall ShapeInt ShapeInt a)
fullRankTall = condition Util.fullRankTall tall


-- | Random wide matrix.
-- It is a tall matrix with both its dimension variables and its array
-- transposed.
wide :: (Class.Floating a) => MatrixInt a (Matrix.Wide ShapeInt ShapeInt a)
wide = fmap Matrix.transpose (transpose tall)

-- | Transposed 'fullRankTall'.
fullRankWide ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   MatrixInt a (Matrix.Wide ShapeInt ShapeInt a)
fullRankWide = fmap Matrix.transpose (transpose fullRankTall)


-- | Random Hermitian matrix of a fresh square size.
-- Diagonal entries are drawn as reals with 'Util.genReal' and converted
-- back with 'fromReal'; the remaining entries come from
-- 'Util.genArrayExtraDiag'.
hermitian ::
   (Dim sh, Shape.Indexed sh, Shape.Index sh ~ ix, Eq ix,
    Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Square sh a (Hermitian sh a)
hermitian =
   mapGen
      (\maxElem shape ->
         Util.genArrayExtraDiag maxElem shape
            (\_ -> fmap fromReal (Util.genReal maxElem)))
      (shapeFromDims MatrixShape.hermitian squareDim)

-- | Random tall matrix whose height is a sum of two parts.
-- Both the whole matrix and the transposed top block must pass
-- 'Util.fullRankTall'.
-- The width variable is created under the constraint
-- 'Logic.between' applied to the height variable.
lscStack ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Matrix (ShapeInt:+:ShapeInt) ShapeInt a
      (Matrix.Tall (ShapeInt:+:ShapeInt) ShapeInt a)
lscStack =
   condition
      (Util.fullRankTall . Matrix.transpose .
       Matrix.wideFromGeneral . Matrix.takeTop . Matrix.fromFull) $
   condition Util.fullRankTall $
   mapGen Util.genArray $
   shapeFromDims
      (\order ~(height, width) -> MatrixShape.tall order height width) $
   Base $ do
      height <- newVariable
      width <- newVariableWith $ Logic.between height
      return
         ((height, width),
          AppHT.mapPair (Logic.query, Logic.query) (height, width))


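{-
There cannot be a pure/point function for 'T', hence no 'Applicative' instance.
The binary combinators below take the role of '<*>'.
Each one applies generated functions to generated arguments
and relates the dimensions of both operands in its own way.
-}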
-- | Matrix times matrix.
-- The inner (fused) dimensions must be equal;
-- the result is height x width.
(<#*#>) ::
   (Dim fuse, Eq fuse) =>
   Matrix height fuse elem (a -> b) ->
   Matrix fuse width elem a ->
   Matrix height width elem b
matF <#*#> matA =
   combine
      (\(height, fuseF) (fuseA, width) -> ((height, width), fuseF =!= fuseA))
      matF matA

-- | Combine a scalar with any object.
-- The object's dimensions pass through and no constraint is added.
(<.*#>) :: Scalar elem (a -> b) -> T dim elem a -> T dim elem b
scalarF <.*#> arg = combine (\() dims -> (dims, mempty)) scalarF arg

-- | Rearrange the (height, width) dimension variables of a matrix generator.
-- The generated object itself is not modified.
mapDims ::
   (forall s.
    (Logic.Variable s heightA, Logic.Variable s widthA) ->
    (Logic.Variable s heightB, Logic.Variable s widthB)) ->
   Matrix heightA widthA elem a ->
   Matrix heightB widthB elem a
mapDims f =
   liftBase (\(Base gen) -> Base (fmap (mapFst f) gen))

-- | Swap the height and width variables.
-- The generated object is unchanged and must be transposed separately
-- (compare 'wide').
transpose ::
   Matrix height width elem a ->
   Matrix width height elem a
transpose = mapDims (\ ~(height, width) -> (width, height))

-- | Use the width variable for both dimensions and drop the height variable.
-- The generated object is unchanged.
gramian ::
   Matrix height width elem a ->
   Matrix width width elem a
gramian = mapDims (\(_height, width) -> (width, width))

-- | Dimension pattern of @transpose a * b@.
-- The heights of both operands must agree; the result is width x nrhs.
(<#\#>) ::
   (Dim height, Eq height) =>
   Matrix height width elem (a -> b) ->
   Matrix height nrhs elem a ->
   Matrix width nrhs elem b
matF <#\#> matA = transpose matF <#*#> matA

-- | Dimension pattern of @a * transpose b@.
-- The widths of both operands must agree; the result is nlhs x height.
(<#/#>) ::
   (Dim width, Eq width) =>
   Matrix nlhs width elem (a -> b) ->
   Matrix height width elem a ->
   Matrix nlhs height elem b
matF <#/#> matA = matF <#*#> transpose matA

-- | Combine two vectors into a matrix.
-- The first vector's size becomes the height and the second's the width;
-- no constraint is added.
(<|*->) ::
   Vector height elem (a -> b) ->
   Vector width elem a ->
   Matrix height width elem b
vecH <|*-> vecW =
   combine (\height width -> ((height, width), mempty)) vecH vecW


-- | Combine two matrices into one with paired dimensions.
-- Fresh height and width variables (allocated in that order) are tied
-- to the operands' dimensions with '!*!'.
(<><>) ::
   Matrix heightA widthA elem (a -> b) ->
   Matrix heightB widthB elem a ->
   Matrix (heightA,heightB) (widthA,widthB) elem b
matF <><> matA =
   combineM
      (\(heightA, widthA) (heightB, widthB) -> do
         height <- newVariableWith (heightA !*! heightB)
         width <- newVariableWith (widthA !*! widthB)
         return (height, width))
      matF matA


-- | Combine two matrices of equal height and equal width.
(<#=#>) ::
   (Dim height, Eq height) =>
   (Dim width, Eq width) =>
   Matrix height width elem (a -> b) ->
   Matrix height width elem a ->
   Matrix height width elem b
matF <#=#> matA =
   combine
      (\(heightF, widthF) (heightA, widthA) ->
         ((heightF, widthF), (heightF =!= heightA) <> (widthF =!= widthA)))
      matF matA


-- | Stack two matrices vertically.
-- The widths must agree; the result height is a fresh variable tied to
-- both heights with '!+!'.
(<===>) ::
   (Dim width, Eq width) =>
   Matrix heightA width elem (a -> b) ->
   Matrix heightB width elem a ->
   Matrix (heightA:+:heightB) width elem b
matF <===> matA =
   combineM
      (\(heightA, widthA) (heightB, widthB) ->
         MW.tell (widthA =!= widthB) >>
         fmap (flip (,) widthA) (newVariableWith (heightA !+! heightB)))
      matF matA

-- | Stack two matrices horizontally.
-- The heights must agree; this is implemented as vertical stacking of
-- the transposed dimensions.
(<|||>) ::
   (Dim height, Eq height) =>
   Matrix height widthA elem (a -> b) ->
   Matrix height widthB elem a ->
   Matrix height (widthA:+:widthB) elem b
matF <|||> matA = transpose (transpose matF <===> transpose matA)



-- | Place two independent matrices on a block diagonal.
-- Fresh height and width variables are tied with '!+!' to the sums of
-- the operands' heights and widths.
stackDiagonal ::
   (Dim heightA, Eq heightA) =>
   (Dim widthB, Eq widthB) =>
   Matrix heightA widthA elem a ->
   Matrix heightB widthB elem c ->
   Matrix (heightA:+:heightB) (widthA:+:widthB) elem (a,c)
stackDiagonal genA genC =
   combineM
      (\(heightA, widthA) (heightB, widthB) -> do
         height <- newVariableWith (heightA !+! heightB)
         width <- newVariableWith (widthA !+! widthB)
         return (height, width))
      (fmap (,) genA) genC

-- | Generate the three blocks of an upper block-triangular arrangement.
-- The layout is @[[A, B], [0, C]]@:
--
-- * A and B must share their height;
-- * B and C must share their width;
-- * the result dimensions are fresh variables tied with '!+!'.
stack3 ::
   (Dim heightA, Eq heightA) =>
   (Dim widthB, Eq widthB) =>
   Matrix heightA widthA elem a ->
   Matrix heightA widthB elem b ->
   Matrix heightB widthB elem c ->
   Matrix (heightA:+:heightB) (widthA:+:widthB) elem (a,b,c)
stack3 genA genB genC =
   combine3M
      (\(heightA, widthA) (heightA0, widthB0) (heightB, widthB) -> do
         MW.tell ((heightA =!= heightA0) <> (widthB =!= widthB0))
         height <- newVariableWith (heightA !+! heightB)
         width <- newVariableWith (widthA !+! widthB)
         return (height, width))
      (fmap (,,) genA) genB genC


-- All generator combinators share one fixity, so they chain left to right.
-- combinators involving scalars or vectors
infixl 4 <-*|>, <.*#>, <-*#>, <#*|>, <|*->
-- combinators of two matrices
infixl 4 <#*#>, <#\#>, <#/#>, <><>
-- element-wise combination and stacking
infixl 4 <|=|>, <#=#>, <+++>, <===>, <|||>