{-# LANGUAGE QuasiQuotes #-}

{- |
Module      : Language.Egison.Math.Rewrite
Licence     : MIT

This module implements rewrite rules for common mathematical functions.
-}

module Language.Egison.Math.Rewrite
  ( rewriteSymbol
  ) where

import           Control.Egison

import           Language.Egison.Math.Arith
import           Language.Egison.Math.Expr
import           Language.Egison.Math.Normalize
import {-# SOURCE #-} Language.Egison.Data (WHNFData)


-- | Run every symbol-level rewrite rule over a scalar, once each.
--
-- The rules are applied in list order: 'rewriteI' sees the input first and
-- 'rewriteDd' produces the final result.  'rewriteSinCos' is defined in this
-- module but is currently left out of the pipeline.
rewriteSymbol :: ScalarData -> ScalarData
rewriteSymbol = foldr (.) id (reverse pipeline)
 where
  -- Listed in application order; reversed above so that the head of the
  -- list ends up innermost in the composition (i.e. runs first).
  pipeline =
    [ rewriteI
    , rewriteW
    , rewriteLog
--    , rewriteSinCos
    , rewriteExp
    , rewritePower
    , rewriteSqrt
    , rewriteRt
    , rewriteRtu
    , rewriteDd
    ]

-- | Apply a term-level rewrite to every term of the numerator and of the
-- denominator of a scalar; the two polynomials are rewritten independently.
mapTerms :: (TermExpr -> TermExpr) -> ScalarData -> ScalarData
mapTerms f (Div (Plus numTerms) (Plus denTerms)) =
  Div (overTerms numTerms) (overTerms denTerms)
 where
  overTerms = Plus . map f

-- | Like 'mapTerms', but the rewrite may turn a single term into an
-- arbitrary scalar.
--
-- The rewritten terms of each polynomial are summed with 'mathPlus'
-- (starting from the scalar 0, written as 0/1), and the two sums are then
-- recombined with 'mathDiv'.
mapTerms' :: (TermExpr -> ScalarData) -> ScalarData -> ScalarData
mapTerms' f (Div (Plus numTerms) (Plus denTerms)) =
  mathDiv (sumRewritten numTerms) (sumRewritten denTerms)
 where
  zeroScalar = Div (Plus []) (Plus [Term 1 []])
  sumRewritten ts = foldl (\acc t -> mathPlus acc (f t)) zeroScalar ts

-- | Apply a polynomial-level rewrite to both the numerator and the
-- denominator of a scalar.
mapPolys :: (PolyExpr -> PolyExpr) -> ScalarData -> ScalarData
mapPolys rewrite (Div numerator denominator) =
  Div (rewrite numerator) (rewrite denominator)

-- | Reduce powers of the imaginary unit @i@ in every term.
--
-- A factor @i^k@ is removed and @(-1)^(k `quot` 2)@ is folded into the
-- coefficient; a single @i@ factor is put back when @k@ is odd.
rewriteI :: ScalarData -> ScalarData
rewriteI = mapTerms reduceI
 where
  reduceI term@(Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (symbol #"i", $k) : $xss ->
              Term (a * (-1) ^ (quot k 2))
                   (if even k then xss else (Symbol "" "i" [], 1) : xss) |]
      , [mc| _ -> term |]
      ]

-- | Simplify expressions in the symbol @w@, using the identities
-- @w^3 = 1@ and @w^2 = -1 - w@ (as for a primitive cube root of unity).
--
-- Term level: a factor @w^k@ with @k >= 3@ becomes @w^(k `mod` 3)@.
-- Polynomial level: a pair of terms @a*w^2*m + b*w*m@ (same remaining
-- factors @m@) becomes @-a*m + (b - a)*w*m@; this is repeated until no such
-- pair is left.
rewriteW :: ScalarData -> ScalarData
rewriteW = mapPolys eliminateW2 . mapTerms reduceW
 where
  -- NOTE(review): when k is a multiple of 3 this leaves a (w, 0) factor in
  -- the monomial; presumably dropped by later normalisation -- confirm.
  reduceW term@(Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (symbol #"w", $k & ?(>= 3)) : $xss ->
               Term a ((Symbol "" "w" [], k `mod` 3) : xss) |]
      , [mc| _ -> term |]
      ]
  eliminateW2 poly@(Plus ts) =
    match dfs ts (Multiset TermM)
      [ [mc| term $a ((symbol #"w", #2) : $mr) :
             term $b ((symbol #"w", #1) : #mr) : $pr ->
               eliminateW2 (Plus (Term (-a) mr :
                        Term (b - a) ((Symbol "" "w" [], 1) : mr) : pr)) |]
      , [mc| _ -> poly |]
      ]

-- | Evaluate logarithms of special arguments in every term.
--
-- * A factor @log 0@ (any power) turns the whole term into 0.
-- * A factor @(log (e^n))^m@ with @m >= 1@ is replaced by the integer
--   @n^m@, multiplied into the coefficient.  The argument must be exactly
--   @e^n@: a coefficient other than 1 (e.g. @log (3*e)@) is left untouched,
--   since @log (c*e^n) = log c + n@ is not an integer in general.
rewriteLog :: ScalarData -> ScalarData
rewriteLog = mapTerms simplifyLog
 where
  -- NOTE(review): log 0 is mathematically undefined; mapping it to 0 is
  -- kept from the original rule set -- confirm this convention is intended.
  simplifyLog term@(Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (apply1 #"log" _ zero, _) : _ -> Term 0 [] |]
      , [mc| (apply1 #"log" _ (singleTerm #1 #1 [(symbol #"e", $n)]), $m & ?(>= 1)) : $xss ->
              Term (a * n ^ m) xss |]
      , [mc| _ -> term |]
      ]

-- | Build the symbol representing the quoted function @f@ applied to the
-- given argument list (via 'makeApplyExpr').
makeApply :: WHNFData -> [ScalarData] -> SymbolExpr
makeApply f = makeApplyExpr (SingleSymbol (QuoteFunction f))

-- | Evaluate and combine exponentials in every term.
--
-- Rules, tried in order; after a successful rewrite the term is rewritten
-- again:
--
-- * @exp(0)^m@            becomes 1
-- * @exp(1)^m@            becomes @e^m@
-- * @exp(n*i*pi)^m@       becomes @(-1)^(n*m)@, folded into the coefficient
-- * @exp(x)^n@, @n >= 2@  becomes @exp(n*x)@
-- * @exp(x) * exp(y)@     becomes @exp(x + y)@
--
-- The exponent @m@ of the @exp@ factor is honoured in the second and third
-- rules, and the sign is computed from the parity of @n*m@ so that a
-- negative @n@ (e.g. @exp(-i*pi)@) does not hit @(^)@'s negative-exponent
-- error.
rewriteExp :: ScalarData -> ScalarData
rewriteExp = mapTerms simplifyExp
 where
  simplifyExp term@(Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (apply1 #"exp" _ zero, _) : $xss ->
               simplifyExp (Term a xss) |]
      , [mc| (apply1 #"exp" _ (singleTerm #1 #1 []), $m) : $xss ->
               simplifyExp (Term a ((Symbol "" "e" [], m) : xss)) |]
      , [mc| (apply1 #"exp" _ (singleTerm $n #1 [(symbol #"i", #1), (symbol #"π", #1)]), $m) : $xss ->
               simplifyExp (Term (if even (n * m) then a else negate a) xss) |]
      , [mc| (apply1 #"exp" $expWhnf $x, $n & ?(>= 2)) : $xss ->
               simplifyExp (Term a ((makeApply expWhnf [mathScalarMult n x], 1) : xss)) |]
      , [mc| (apply1 #"exp" $expWhnf $x, #1) : (apply1 #"exp" _ $y, #1) : $xss ->
               simplifyExp (Term a ((makeApply expWhnf [mathPlus x y], 1) : xss)) |]
      , [mc| _ -> term |]
      ]

-- | Simplify powers built with the @^@ function in every term.
--
-- Rules, tried in order; after a successful rewrite the term is rewritten
-- again:
--
-- * a @^@ factor matched by @apply1 #"^" _ 1@ is dropped
--   (NOTE(review): this matches @^@ with a single argument equal to 1 --
--   confirm which expressions the @apply1@ matcher accepts here);
-- * @(x^y)^n@ with @n >= 2@ becomes @x^(n*y)@;
-- * @x^y * x^z@ becomes @x^(y + z)@.
rewritePower :: ScalarData -> ScalarData
rewritePower = mapTerms simplifyPower
 where
  simplifyPower term@(Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (apply1 #"^" _ (singleTerm #1 #1 []), _) : $xss -> simplifyPower (Term a xss) |]
      , [mc| (apply2 #"^" $powerWhnf $x $y, $n & ?(>= 2)) : $xss ->
               simplifyPower (Term a ((makeApply powerWhnf [x, mathScalarMult n y], 1) : xss)) |]
      , [mc| (apply2 #"^" $powerWhnf $x $y, #1) : (apply2 #"^" _ #x $z, #1) : $xss ->
               simplifyPower (Term a ((makeApply powerWhnf [x, mathPlus y z], 1) : xss)) |]
      , [mc| _ -> term |]
      ]

-- | Evaluate sines and cosines at multiples of pi/2, and rewrite @cos^2@ in
-- terms of @sin^2@.  (Not currently part of the 'rewriteSymbol' pipeline.)
--
-- Term level, sines first and then cosines:
--
-- * @sin(0)@ and @sin(n*pi)@ make the term 0;
-- * @sin(n*pi/2)^m@ becomes the sign @((-1)^((n - 1) `div` 2))^m@, folded
--   into the coefficient (this formula is for odd @n@; an even @n@ is
--   presumably already reduced to a denominator of 1 -- confirm);
-- * @cos(0)@ becomes 1;
-- * @cos(n*pi/2)@ makes the term 0;
-- * @cos(n*pi)^m@ becomes the sign @(-1)^(n*m)@, folded into the coefficient.
--
-- Signs are computed from parities, so negative @n@ or @m@ neither give a
-- wrong sign nor hit @(^)@'s negative-exponent error.
--
-- Fraction level: a numerator term @a*cos(x)^2*m@ is rewritten with
-- @cos(x)^2 = 1 - sin(x)^2@ whenever a term @sin(x)^2*m@ with the same
-- remaining factors exists in the numerator (the two are merged) or in the
-- denominator.
rewriteSinCos :: ScalarData -> ScalarData
rewriteSinCos = h . mapTerms (g . f)
 where
  f term@(Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (apply1 #"sin" _ zero, _) : _ -> Term 0 [] |]
      , [mc| (apply1 #"sin" _ (singleTerm _ #1 [(symbol #"π", #1)]), _) : _ ->
               Term 0 [] |]
      , [mc| (apply1 #"sin" _ (singleTerm $n #2 [(symbol #"π", #1)]), $m) : $xss ->
              Term (if even (((n - 1) `div` 2) * m) then a else negate a) xss |]
      , [mc| _ -> term |]
      ]
  g term@(Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (apply1 #"cos" _ zero, _) : $xss -> Term a xss |]
      , [mc| (apply1 #"cos" _ (singleTerm _ #2 [(symbol #"π", #1)]), _) : _ ->
              Term 0 [] |]
      , [mc| (apply1 #"cos" _ (singleTerm $n #1 [(symbol #"π", #1)]), $m) : $xss ->
               Term (if even (n * m) then a else negate a) xss |]
      , [mc| _ -> term |]
      ]
  h (Div poly1@(Plus ts1) poly2@(Plus ts2)) =
    match dfs (ts1, ts2) (Multiset TermM, Multiset TermM)
      [ [mc| ((term $a ((apply1 #"cos" $cosWhnf $x, #2) : $mr)) : (term $b ((apply1 #"sin" $sinWhnf #x, #2) : #mr)) : $pr, _) ->
              h (Div (Plus (Term a mr : Term (b - a) ((makeApply sinWhnf [x], 2) : mr) : pr)) poly2) |]
      , [mc| ((term $a ((apply1 #"cos" $cosWhnf $x, #2) : $mr)) : $pr1, (term _ ((apply1 #"sin" $sinWhnf #x, #2) : #mr)) : _) ->
              h (Div (Plus (Term a mr : Term (- a) ((makeApply sinWhnf [x], 2) : mr) : pr1)) poly2) |]
      , [mc| _ -> Div poly1 poly2 |]
      ]

-- | Decide whether a scalar is definitely negative.
--
-- Returns @Just True@ if negative, @Just False@ if non-negative, and
-- @Nothing@ if unknown.  Only scalars whose denominator is a single non-zero
-- integer constant @d@ are analysed (a negative @d@ flips the answer).  The
-- numerator is judged as follows:
--
-- * no terms at all (the scalar is 0): non-negative;
-- * all coefficients negative / all positive: negative / non-negative.
--   Only the integer coefficients are inspected here; the sign of any
--   symbolic factor in a term is not considered;
-- * exactly @a + b*sqrt(n)@ with integers @a@, @b@ and @n > 0@: decided by
--   comparing @a^2@ with @b^2*n@;
-- * anything else: unknown.
isNegativeScalar :: ScalarData -> Maybe Bool
isNegativeScalar (Div (Plus terms) (Plus [Term d []]))
  -- 0 / d is zero, hence not negative.  This must be checked before
  -- analyzeTerms, where `all` over an empty list would claim "negative".
  | d /= 0, null terms = Just False
  | d > 0 = analyzeTerms terms
  | d < 0 = not <$> analyzeTerms terms
 where
  analyzeTerms :: [TermExpr] -> Maybe Bool
  analyzeTerms ts
    | all (\(Term a _) -> a < 0) ts = Just True
    | all (\(Term a _) -> a > 0) ts = Just False
    | otherwise =
      -- Two-term case: a + b*sqrt(n), compare a^2 with b^2*n
      match dfs ts (Multiset TermM)
        [ [mc| term $a [] :
               term $b ((apply1 #"sqrt" _ (singleTerm $n #1 []), #1) : []) :
               [] ->
                 if n > 0
                 then let lhs = a * a; rhs = b * b * n
                      in if lhs > rhs then Just (a < 0)
                         else if lhs < rhs then Just (b < 0)
                         else Just False
                 else Nothing |]
        , [mc| _ -> Nothing |]
        ]
isNegativeScalar _ = Nothing

-- | Find two @sqrt@ factors in a monomial whose product simplifies to a
-- single term, so that the pair can be merged into one @sqrt@.
--
-- All pairs of first-power @sqrt@ factors are enumerated with 'matchAll'
-- (avoiding DFS ordering issues), and the first usable pair is returned as
-- @(sqrtWhnf, product, remainingFactors, sign)@.
--
-- The product is passed through 'rewriteSqrt' because 'mathMult' alone does
-- not simplify @sqrt(x)^2@ to @x@, which is needed for products like
-- @(-5-2*sqrt 5)*(-5+2*sqrt 5)@.  The sign is @-1@ when both radicands are
-- known to be negative (then @sqrt x * sqrt y = -sqrt (x*y)@), otherwise 1.
findSqrtPairToMerge :: Monomial -> Maybe (WHNFData, ScalarData, Monomial, Integer)
findSqrtPairToMerge xs = foldr (\candidate _ -> Just candidate) Nothing candidates
 where
  candidates =
    [ (whnf, simplified, xss, mergeSign x y)
    | (whnf, x, y, xss) <- sqrtPairs
    , let simplified = rewriteSqrt (mathMult x y)
    , isSingleTermScalar simplified
    ]
  sqrtPairs =
    matchAll dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (apply1 #"sqrt" $whnf $x, #1) :
               (apply1 #"sqrt" _ $y, #1) : $xss ->
                 (whnf, x, y, xss) |] ]
  mergeSign x y =
    case (isNegativeScalar x, isNegativeScalar y) of
      (Just True, Just True) -> -1
      _                      -> 1
  isSingleTermScalar (Div (Plus [_]) (Plus [_])) = True
  isSingleTermScalar _ = False

-- | Simplify square roots in every term; a term may become a general
-- scalar, hence 'mapTerms''.
--
-- Rules, tried in order:
--
-- * @sqrt(x)^k@ with @k > 1@ becomes @sqrt(x)^(k mod 2) * x^(k div 2)@,
--   and the result is simplified again;
-- * @sqrt(n*x) * sqrt(m*y)@ (integer denominators 1): the gcd @c*z@ of the
--   two radicands (from 'termsGcd') is pulled out, leaving
--   @sqrt(n'*x' * m'*y')@, or nothing when the reduced radicands have no
--   symbols and @n'*m' == 1@;
-- * otherwise, a pair found by 'findSqrtPairToMerge' is merged into a
--   single @sqrt@ (with its sign) and the result is simplified again.
rewriteSqrt :: ScalarData -> ScalarData
rewriteSqrt = mapTerms' sqrtTerm
 where
  -- NOTE(review): the gcd rule does not flip the sign when both reduced
  -- coefficients n' and m' are negative; whether that is needed depends on
  -- the sign convention of termsGcd -- confirm.
  sqrtTerm (Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (apply1 #"sqrt" $sqrtWhnf $x, ?(> 1) & $k) : $xss ->
               rewriteSqrt
                 (mathMult (SingleTerm a ((makeApply sqrtWhnf [x], k `mod` 2) : xss))
                           (mathPower x (div k 2))) |]
      , [mc| (apply1 #"sqrt" $sqrtWhnf (singleTerm $n #1 $x), #1) :
               (apply1 #"sqrt" _ (singleTerm $m #1 $y), #1) : $xss ->
             let d@(Term c z) = termsGcd [Term n x, Term m y]
                 Term n' x' = mathDivideTerm (Term n x) d
                 Term m' y' = mathDivideTerm (Term m y) d
                 in case (n' * m', Term n' x', Term m' y') of
                      (1, Term _ [], Term _ []) -> mathMult (SingleTerm c z) (SingleTerm a xss)
                      (_, _, _) -> mathMult (SingleTerm c z) (SingleTerm a ((makeApply sqrtWhnf [SingleTerm (n' * m') (x' ++ y')], 1) : xss)) |]
      , [mc| _ -> case findSqrtPairToMerge xs of
                    Just (whnf, product, remaining, sign) ->
                      rewriteSqrt (SingleTerm (sign * a) ((makeApply whnf [product], 1) : remaining))
                    Nothing -> SingleTerm a xs |]
      ]

-- | Reduce high powers of @rt n x@ (treated as an n-th root of @x@) in
-- every term: @(rt n x)^k@ with @k >= n@ becomes
-- @(rt n x)^(k mod n) * x^(k div n)@.
rewriteRt :: ScalarData -> ScalarData
rewriteRt = mapTerms' reduceRoot
 where
  reduceRoot (Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (apply2 #"rt" _ (singleTerm $n #1 []) $x & $rtnx, ?(>= n) & $k) : $xss ->
               mathMult (SingleTerm a ((rtnx, k `mod` n) : xss))
                        (mathPower x (div k n)) |]
      , [mc| _ -> SingleTerm a xs |]
      ]

-- | Simplify powers of @rtu n@, treated as a primitive n-th root of unity.
--
-- First, @(rtu n)^k@ with @k >= n@ is reduced to @(rtu n)^(k mod n)@.
-- Then every factor @(rtu n)^(n-1)@ is expanded as
-- @-1 - rtu n - (rtu n)^2 - ... - (rtu n)^(n-2)@
-- (from @1 + w + ... + w^(n-1) = 0@), and the rest of the term is processed
-- the same way.
rewriteRtu :: ScalarData -> ScalarData
rewriteRtu s = mapTerms' expandTopPower (mapTerms reducePower s)
 where
  -- NOTE(review): when k is a multiple of n this leaves an (rtu n, 0)
  -- factor in the monomial; presumably dropped by later normalisation.
  reducePower term@(Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (apply1 #"rtu" _ (singleTerm $n #1 []) & $rtun, ?(>= n) & $k) : $r ->
               Term a ((rtun, k `mod` n) : r) |]
      , [mc| _ -> term |]
      ]
  expandTopPower (Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (apply1 #"rtu" _ (singleTerm $n #1 []) & $rtun, ?(== n - 1)) : $mr ->
               mathMult
                 (foldl mathMinus (SingleTerm (-1) []) (map (\k -> SingleTerm 1 [(rtun, k)]) [1..(n-2)]))
                 (expandTopPower (Term a mr)) |]
      , [mc| _ -> SingleTerm a xs |]
      ]

-- | Merge like terms whose function-application factors match.
--
-- Within the numerator and within the denominator, two terms
-- @a*(f^n)*m@ and @b*(f'^n)*m@ are combined into @(a + b)*(f^n)*m@ when
-- @f@ and @f'@ agree on the parts compared by the @func@ pattern (the
-- function @g@ and its arguments), have the same exponent, and share the
-- remaining factors @m@.  The first term's symbol @f@ is kept.  This
-- repeats until no such pair is left.
rewriteDd :: ScalarData -> ScalarData
rewriteDd (Div (Plus numTerms) (Plus denTerms)) =
  Div (Plus (mergeTerms numTerms)) (Plus (mergeTerms denTerms))
 where
  mergeTerms poly =
    match dfs poly (Multiset TermM)
      [ [mc| term $a (($f & func $g $args, $n) : $mr) :
               term $b ((func #g #args, #n) : #mr) : $pr ->
                 mergeTerms (Term (a + b) ((f, n) : mr) : pr) |]
      , [mc| _ -> poly |]
      ]