{-# OPTIONS_GHC -Wno-orphans #-} -- Arbitrary
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Invariants where

import Test.Tasty
import Test.Tasty.QuickCheck as QC hiding (classes)

import Control.Monad

import qualified Data.Containers.ListUtils as LU
import qualified Data.Foldable as F
import qualified Data.List   as L
import qualified Data.IntSet as IS
import qualified Data.IntMap.Strict as IM

import Data.Equality.Graph.Monad as GM
import Data.Equality.Graph.Lens
import Data.Equality.Graph.Internal (EGraph(classes))
import Data.Equality.Graph
import Data.Equality.Extraction
import Data.Equality.Saturation
import Data.Equality.Matching
import Data.Equality.Matching.Database
import Sym

-- | Newtype deriving via Expr to be able to define a different analysis
-- TODO: Use type level symbol to define the analysis
type role SimpleExpr nominal
newtype SimpleExpr l = SE (Expr l)
    deriving (Functor, Foldable, Traversable, Show, Eq, Ord)

-- | When a rewrite of type "x":=c where x is a pattern variable and c is a
-- constant is used in equality saturation of any expression, all e-classes
-- should be merged into a single one, since all classes are equal to c and
-- therefore equivalent to themselves
patFoldAllClasses :: forall l. (Language l, Num (Pattern l))
                  => Fix l -> Integer -> Bool
patFoldAllClasses expr i =
    case IM.toList (classes eg) of
        [_] -> True
        _   -> False
    where
        eg :: EGraph () l
        eg = snd $ equalitySaturation expr [VariablePattern 1:=fromInteger i] (error "Cost function shouldn't be used" :: CostFunction l Int)

-- | Test 'compileToQuery'.
--
-- Every pattern compiled to a query should have the same number of free variables (except for the root variable)
-- as the pattern
--
-- The number of atoms should also match the number of non variable patterns
-- since we should create an additional atom (with a new bound variable) for each. 
testCompileToQuery :: Traversable lang => Pattern lang -> Bool
testCompileToQuery p = case fst $ compileToQuery p of
                         -- Handle special case for selectAll queries...
                         SelectAllQuery x -> [x] == vars p && numNonVarPatterns p == 0
                         q@(Query _ atoms)
                           | [] <- queryHeadVars q   -> False
                           | _:xs <- queryHeadVars q ->
                               L.sort xs == L.sort (vars p)
                                 && length atoms == numNonVarPatterns p
    where
        numNonVarPatterns :: Foldable lang => Pattern lang -> Int
        numNonVarPatterns (VariablePattern _) = 0
        numNonVarPatterns (NonVariablePattern l) = F.foldl' (flip $ (+) . numNonVarPatterns) 1 l

        queryHeadVars :: Foldable lang => Query lang -> [Var]
        queryHeadVars (SelectAllQuery x) = [x]
        queryHeadVars (Query qv _) = qv

        -- | Return distinct variables in a pattern
        vars :: Foldable lang => Pattern lang -> [Var]
        vars (VariablePattern x) = [x]
        vars (NonVariablePattern p') = LU.nubInt $ join $ map vars $ F.toList p'

-- | If we match a singleton variable pattern against an e-graph, we should get
-- a match on all e-classes in the e-graph
ematchSingletonVar :: Language lang => Var -> EGraph () lang -> Bool
ematchSingletonVar v eg =
    let
        db = eGraphToDatabase eg
        matches = IS.fromList $ map matchClassId $ ematch db (VariablePattern v)
        eclasses = IM.keysSet (classes eg)
    in
        matches == eclasses 


-- | Property test for 'genericJoin'.
--
-- If we search a database with an expression in which all patterns are
-- variables (the only non-variable pattern is the top one), then, altogether,
-- we should get a list of all e-classes 
-- genericJoinAll :: Database lang -> 


-- The equivalence relation over e-nodes must be closed over congruence after rebuilding
-- congruenceInvariant :: Testable m (EGraph lang) => Property m


-- The hashcons 𝐻  must map all canonical e-nodes to their e-class ids
--
-- Note: the e-graph argument must have been rebuilt -- checking the property
-- when invariants are broken for sure doesn't make much sense
--
-- ROMES:TODO Should I rebuild it here? Then the property test is that after rebuilding ...HashConsInvariant
hashConsInvariant :: forall l. Language l
                  => EGraph () l -> Bool
hashConsInvariant eg =
    allOf _iclasses f eg
    where
      -- e-node 𝑛 ∈ 𝑀 [𝑎] ⇐⇒ 𝐻 [canonicalize(𝑛)] = find(𝑎)
      f (i, EClass{eClassNodes=nodes}) = all g nodes
        where
          g en = case lookupNM (canonicalize en eg) (eg^._memo) of
            Nothing -> error "how can we not find canonical thing in map? :)" -- False
            Just i' -> i' == find i eg 

benchSaturate :: forall l. Language l
              => [Rewrite () l] -> (l Int -> Int) -> Fix l -> Bool
benchSaturate rws cost expr =
    equalitySaturation expr rws cost `seq` True


-- ROMES:TODO: Property: Extract expression after equality saturation is always better or equal to the original expression

-- ROMES:TODO: Use action trick https://jaspervdj.be/posts/2015-03-13-practical-testing-in-haskell.html
instance Arbitrary (EGraph () SimpleExpr) where
    arbitrary = sized $ \n -> do
        exps <- forM [0..n] $ const arbitrary
        -- rws :: [Rewrite Expr] <- forM [0..n] $ const arbitrary
        (ids, eg) <- return $ egraph $
            mapM GM.represent exps
        ids1 <- sublistOf ids
        ids2 <- sublistOf ids
        return $ snd $ runEGraphM eg $ do
            forM_ (zip ids1 ids2) $ \(a,b) -> do
                GM.merge a b
            GM.rebuild

instance Arbitrary BOp where
    arbitrary = oneof [ return Add
                      , return Sub
                      , return Mul
                      , return Div ]

instance Arbitrary UOp where
    arbitrary = oneof [ return Sin
                      , return Cos
                      ]

instance Arbitrary a => Arbitrary (SimpleExpr a) where
    arbitrary = SE <$> arbitrary

instance Arbitrary a => Arbitrary (Expr a) where
    arbitrary = sized expr'
        where
            expr' :: Int -> Gen (Expr a)
            expr' 0 = oneof [ Sym . un <$> arbitrary
                            , Const . fromInteger <$> arbitrary
                            ]
            expr' n
              | n > 0 = oneof [ BinOp <$> arbitrary <*> resize (n `div` 2) arbitrary <*> resize (n `div` 2) arbitrary
                              , UnOp <$> arbitrary <*> resize (n - 1) arbitrary ]
            expr' _ = error "size is negative?"

instance Arbitrary (Fix SimpleExpr) where
    arbitrary = Fix <$> arbitrary

instance Arbitrary (Fix Expr) where
    arbitrary = Fix <$> arbitrary

instance Arbitrary (Pattern SimpleExpr) where
    arbitrary = sized p'
      where
        p' 0 = VariablePattern <$> oneof (return <$> [1..16])
        p' n = NonVariablePattern <$> resize (n `div` 2) arbitrary

newtype Name = Name { un :: String }

instance Arbitrary Name where
  arbitrary = oneof (return . Name . (:[]) <$> ['a'..'l'])

instance Num (Pattern SimpleExpr) where
    fromInteger = NonVariablePattern . SE . Const . fromInteger
    (+) = error "Should use @Expr or have other way to switch analysis"
    (*) = error "Should use @Expr or have other way to switch analysis"
    (-) = error "Should use @Expr or have other way to switch analysis"
    abs = error "Should use @Expr or have other way to switch analysis"
    signum = error "Should use @Expr or have other way to switch analysis"

invariants :: TestTree
invariants = testGroup "Invariants"
  [ QC.testProperty "Compile to query" (testCompileToQuery @SimpleExpr)
    -- TODO: This bench is still failing because of the bad rewrite scheduler
    -- TODO: Much infinite looping ...
  -- , QC.testProperty "Bench saturation @Expr" (withMaxSuccess 10 (benchSaturate @Expr rewrites symCost))
  , QC.testProperty "Singleton variable matches all" (ematchSingletonVar @SimpleExpr)
  , QC.testProperty "Hash Cons Invariant" (hashConsInvariant @SimpleExpr)
  , QC.testProperty "Fold all classes with x:=c" (patFoldAllClasses @SimpleExpr)
  ]