{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DeriveTraversable #-}
module SimpleSym where

import Test.Tasty
import Test.Tasty.HUnit

import Data.Eq.Deriving
import Data.Ord.Deriving
import Text.Show.Deriving

import Data.Equality.Utils
import Data.Equality.Matching
import Data.Equality.Saturation
import Data.Equality.Language
import Data.Equality.Analysis
import Data.Equality.Graph.Lens ((^.), _data)

-- | A tiny symbolic arithmetic language written as a base functor.
--
-- The type parameter marks subterm positions. @'Fix' SymExpr@ is a whole
-- expression. 'cost' sees @SymExpr Int@, with children already reduced to
-- their costs. The analysis sees @SymExpr (Maybe Double)@, with children
-- reduced to their known constant values.
data SymExpr e
  = Const Double    -- ^ numeric literal
  | Symbol String   -- ^ named variable
  | e :+: e         -- ^ addition
  | e :*: e         -- ^ multiplication
  | e :/: e         -- ^ division
  deriving (Functor, Foldable, Traversable)

-- Multiplication and division bind tighter than addition; none associate.
infix 6 :+:
infix 7 :*:, :/:

-- Generate the lifted comparison/printing instances (Eq1, Ord1, Show1) for
-- 'SymExpr'. NOTE(review): these are presumably what give @Fix SymExpr@ the
-- Eq/Show instances that '@?=' needs, and the Ord1 that the e-graph uses.
$(deriveEq1 ''SymExpr)
$(deriveOrd1 ''SymExpr)
$(deriveShow1 ''SymExpr)

-- | Registers 'SymExpr' as a language usable with the e-graph machinery.
-- No methods are overridden here.
instance Language SymExpr where

-- | Constant-folding analysis. An e-class carries @Just v@ when it is known
-- to evaluate to the constant @v@, and 'Nothing' otherwise.
instance Analysis (Maybe Double) SymExpr where
  -- Fold a node whose children are all known constants; a symbol is unknown,
  -- and any unknown child makes the whole node unknown.
  makeA = \case
    Const x -> Just x
    Symbol _ -> Nothing
    x :+: y -> (+) <$> x <*> y
    x :*: y -> (*) <$> x <*> y
    x :/: y -> (/) <$> x <*> y

  -- Merge the data of two e-classes being unified: keep whichever constant
  -- is known. Two known constants must agree; otherwise a rewrite equated
  -- different values. NaN is handled explicitly because @NaN /= NaN@ under
  -- '==', which previously made any NaN-valued class crash on merge.
  joinA Nothing my = my
  joinA mx Nothing = mx
  joinA (Just x) (Just y)
    | x == y || (isNaN x && isNaN y) = Just x
    | otherwise =
        error ("SimpleSym.joinA: conflicting constants " ++ show x ++ " and " ++ show y)

  -- For a class with a known value @i@, also add the literal @Const i@, which
  -- 'cost' prices at 1, so extraction can choose it.
  -- NaN is skipped: 'compare' on NaN never yields EQ, so an Ord-based
  -- hash-cons presumably cannot deduplicate a @Const NaN@ node, and repeated
  -- injection could keep creating fresh classes.
  -- NOTE(review): confirm against hegg's memo lookup.
  modifyA c = case c ^. _data of
    Just i | not (isNaN i) -> (c, [Fix (Const i)])
    _                      -> (c, [])

-- | Extraction cost of a node, given its children's costs.
--
-- Leaves (constants and symbols) cost 1. Each operator adds a fixed weight
-- (+ is 2, * is 3, / is 4) on top of the summed costs of its children.
cost :: CostFunction SymExpr Int
cost node = opWeight node + sum node
  where
    -- The derived Foldable visits only child positions, so 'sum' is 0 for
    -- leaves and the children's total for operators.
    opWeight = \case
      Const  _ -> 1
      Symbol _ -> 1
      _ :+: _  -> 2
      _ :*: _  -> 3
      _ :/: _  -> 4

-- | Rewrite rules used during saturation, listed in application order:
--
-- * @(a*b)/c => a*(b/c)@: move division onto the right factor
-- * @x/x => 1@: self-division (NOTE(review): unsound when @x@ is 0)
-- * @x*1 => x@: multiplicative identity on the right
rewrites :: [Rewrite (Maybe Double) SymExpr]
rewrites = [mulDivAssoc, divSelf, mulOne]
  where
    mulDivAssoc = pat (pat ("a" :*: "b") :/: "c") := pat ("a" :*: pat ("b" :/: "c"))
    divSelf     = pat ("x" :/: "x") := pat (Const 1)
    mulOne      = pat ("x" :*: pat (Const 1)) := "x"

-- | Saturate an expression with 'rewrites' and return the cheapest
-- equivalent term according to 'cost'. The second component of the
-- saturation result is discarded.
rewrite :: Fix SymExpr -> Fix SymExpr
rewrite expr = best
  where
    (best, _finalGraph) = equalitySaturation expr rewrites cost

-- | The expression @(x*2)/2@, expected to simplify back to @x@.
e1 :: Fix SymExpr
e1 = Fix (Fix (x :*: two) :/: two)
  where
    x   = Fix (Symbol "x")
    two = Fix (Const 2)

-- | Each case saturates an expression with 'rewrites' and checks the
-- cheapest extracted term under 'cost'.
simpleSymTests :: TestTree
simpleSymTests = testGroup "Simple Sym"
    [ testCase "(x*2)/2 = x" $
        rewrite e1 @?= Fix (Symbol "x")
    , testCase "3+(x/x) = 4" $
        rewrite threePlusXOverX @?= Fix (Const 4)
    ]
  where
    -- 3 + x/x: the x/x rule yields 1, and constant folding turns 3 + 1 into 4.
    threePlusXOverX =
      Fix (Fix (Const 3) :+: Fix (Fix (Symbol "x") :/: Fix (Symbol "x")))