{-# LANGUAGE QuasiQuotes #-}
module Language.Egison.Math.Rewrite
( rewriteSymbol
) where
import Control.Egison
import Language.Egison.Math.Arith
import Language.Egison.Math.Expr
import Language.Egison.Math.Normalize
import {-# SOURCE #-} Language.Egison.Data (WHNFData)
-- | Apply every symbolic rewrite rule of this module to a scalar.
--
-- The rules are composed left to right: 'rewriteI' runs first and
-- 'rewriteDd' runs last, each one receiving the previous one's output.
--
-- NOTE(review): 'rewriteSinCos' is defined in this module but is not part of
-- this list; confirm that leaving it out is intentional.
rewriteSymbol :: ScalarData -> ScalarData
rewriteSymbol =
  foldl1 (\acc f -> f . acc)
    [ rewriteI
    , rewriteW
    , rewriteLog
    , rewriteExp
    , rewritePower
    , rewriteSqrt
    , rewriteRt
    , rewriteRtu
    , rewriteDd
    ]
-- | Apply a term-level transformation to every term of both the numerator
-- and the denominator of a scalar.  The result is not re-normalized.
mapTerms :: (TermExpr -> TermExpr) -> ScalarData -> ScalarData
mapTerms f (Div (Plus ts1) (Plus ts2)) =
  Div (Plus (map f ts1)) (Plus (map f ts2))
-- | Like 'mapTerms', but the transformation may turn a single term into a
-- whole scalar.
--
-- The images of the numerator terms are summed with 'mathPlus', likewise the
-- images of the denominator terms, and the two sums are combined with
-- 'mathDiv'.
mapTerms' :: (TermExpr -> ScalarData) -> ScalarData -> ScalarData
mapTerms' f (Div (Plus ts1) (Plus ts2)) =
  mathDiv (sumScalars (map f ts1)) (sumScalars (map f ts2))
 where
  -- The scalar 0, written as 0 / 1.
  zeroScalar :: ScalarData
  zeroScalar = Div (Plus []) (Plus [Term 1 []])

  sumScalars :: [ScalarData] -> ScalarData
  sumScalars = foldl mathPlus zeroScalar
-- | Apply a polynomial-level transformation to the numerator and the
-- denominator of a scalar.
mapPolys :: (PolyExpr -> PolyExpr) -> ScalarData -> ScalarData
mapPolys f (Div p1 p2) = Div (f p1) (f p2)
-- | Reduce powers of the imaginary unit @i@ in every term.
--
-- @i^k@ becomes @(-1)^(k `quot` 2)@, times one remaining factor @i@ when @k@
-- is odd.
rewriteI :: ScalarData -> ScalarData
rewriteI = mapTerms f
 where
  f :: TermExpr -> TermExpr
  f term@(Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (symbol #"i", $k) : $xss ->
               if even k
                 then Term (a * (-1) ^ (quot k 2)) xss
                 else Term (a * (-1) ^ (quot k 2)) ((Symbol "" "i" [], 1) : xss) |]
      , [mc| _ -> term |]
      ]
-- | Simplify the symbol @w@ using identities that hold when @w@ is a
-- primitive cube root of unity.
--
-- * Term level ('f'): @w^k@ with @k >= 3@ becomes @w^(k `mod` 3)@.  When
--   @k@ is a multiple of 3 the factor is dropped entirely (@w^0 = 1@), so no
--   @(w, 0)@ factor is left in the monomial.
-- * Polynomial level ('g', repeated): @a*w^2*m + b*w*m@ becomes
--   @-a*m + (b - a)*w*m@, i.e. @w^2 = -w - 1@.
rewriteW :: ScalarData -> ScalarData
rewriteW = mapPolys g . mapTerms f
 where
  f :: TermExpr -> TermExpr
  f term@(Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (symbol #"w", $k & ?(>= 3)) : $xss ->
               if k `mod` 3 == 0
                 then Term a xss
                 else Term a ((Symbol "" "w" [], k `mod` 3) : xss) |]
      , [mc| _ -> term |]
      ]

  g :: PolyExpr -> PolyExpr
  g poly@(Plus ts) =
    match dfs ts (Multiset TermM)
      [ [mc| term $a ((symbol #"w", #2) : $mr) :
             term $b ((symbol #"w", #1) : #mr) : $pr ->
               g (Plus (Term (-a) mr :
                        Term (b - a) ((Symbol "" "w" [], 1) : mr) : pr)) |]
      , [mc| _ -> poly |]
      ]
-- | Simplify logarithm factors in every term.
--
-- * A term containing a factor @log 0@ (to any power) becomes 0.
-- * @(log (e^n))^m@ becomes the integer @n^m@.  The radicand must be exactly
--   @e^n@ (coefficient 1), because @log (c * e^n) = log c + n@.
rewriteLog :: ScalarData -> ScalarData
rewriteLog = mapTerms f
 where
  f :: TermExpr -> TermExpr
  f term@(Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ -- NOTE(review): this maps log 0 to 0; mathematically it is log 1
        -- that equals 0.  Kept unchanged - confirm the intended rule.
        [mc| (apply1 #"log" _ zero, _) : _ -> Term 0 [] |]
        -- Coefficient pinned to 1 and the factor's exponent m honoured.
      , [mc| (apply1 #"log" _ (singleTerm #1 #1 [(symbol #"e", $n)]), $m) : $xss ->
               Term (a * n ^ m) xss |]
      , [mc| _ -> term |]
      ]
-- | Build an application symbol whose head is the given function value
-- (wrapped with 'QuoteFunction' and 'SingleSymbol') and whose arguments are
-- the given scalars.
makeApply :: WHNFData -> [ScalarData] -> SymbolExpr
makeApply f args = makeApplyExpr (SingleSymbol (QuoteFunction f)) args
-- | Simplify exponential factors in every term.  Rules are tried in order:
--
-- * @exp 0@ (any power) is dropped.
-- * @(exp 1)^m@ becomes @e^m@.
-- * @(exp (n*i*π))^m@ becomes @(-1)^(n*m)@.  The sign is computed by parity,
--   so negative @n@ no longer hits @(^)@'s negative-exponent error.
-- * @(exp x)^n@ with @n >= 2@ becomes @exp (n*x)@.
-- * @exp x * exp y@ becomes @exp (x + y)@.
--
-- After each successful rule the term is rewritten again.
rewriteExp :: ScalarData -> ScalarData
rewriteExp = mapTerms f
 where
  f :: TermExpr -> TermExpr
  f term@(Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (apply1 #"exp" _ zero, _) : $xss ->
               f (Term a xss) |]
      , [mc| (apply1 #"exp" _ (singleTerm #1 #1 []), $m) : $xss ->
               f (Term a ((Symbol "" "e" [], m) : xss)) |]
      , [mc| (apply1 #"exp" _ (singleTerm $n #1 [(symbol #"i", #1), (symbol #"π", #1)]), $m) : $xss ->
               f (Term (if even (n * m) then a else negate a) xss) |]
      , [mc| (apply1 #"exp" $expWhnf $x, $n & ?(>= 2)) : $xss ->
               f (Term a ((makeApply expWhnf [mathScalarMult n x], 1) : xss)) |]
      , [mc| (apply1 #"exp" $expWhnf $x, #1) : (apply1 #"exp" _ $y, #1) : $xss ->
               f (Term a ((makeApply expWhnf [mathPlus x y], 1) : xss)) |]
      , [mc| _ -> term |]
      ]
-- | Simplify @^@ applications in every term.
--
-- * @(x^y)^n@ with @n >= 2@ becomes @x^(n*y)@.
-- * @x^y * x^z@ becomes @x^(y + z)@.
--
-- After each successful rule the term is rewritten again.
rewritePower :: ScalarData -> ScalarData
rewritePower = mapTerms f
 where
  f :: TermExpr -> TermExpr
  f term@(Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ -- NOTE(review): the rules below treat "^" as a two-argument
        -- application (apply2); confirm this one-argument (apply1) pattern
        -- can actually match.
        [mc| (apply1 #"^" _ (singleTerm #1 #1 []), _) : $xss -> f (Term a xss) |]
      , [mc| (apply2 #"^" $powerWhnf $x $y, $n & ?(>= 2)) : $xss ->
               f (Term a ((makeApply powerWhnf [x, mathScalarMult n y], 1) : xss)) |]
      , [mc| (apply2 #"^" $powerWhnf $x $y, #1) : (apply2 #"^" _ #x $z, #1) : $xss ->
               f (Term a ((makeApply powerWhnf [x, mathPlus y z], 1) : xss)) |]
      , [mc| _ -> term |]
      ]
-- | Simplify trigonometric factors.
--
-- Term level ('f', then 'g'):
--
-- * @sin 0@ or @sin (c*π)@ makes the whole term 0.
-- * @(sin (n*π/2))^m@ becomes @(sin (n*π/2))^m@ evaluated as an integer.
-- * @cos 0@ (any power) is dropped.
-- * @cos (c*π/2)@ makes the whole term 0.
-- * @(cos (n*π))^m@ becomes @(-1)^(|n|*m)@.
--
-- Scalar level ('h'): uses @cos^2 x = 1 - sin^2 x@ to eliminate a @cos^2 x@
-- numerator term when a matching @sin^2 x@ term (same remaining factors)
-- occurs in the numerator or in the denominator.
rewriteSinCos :: ScalarData -> ScalarData
rewriteSinCos = h . mapTerms (g . f)
 where
  f :: TermExpr -> TermExpr
  f term@(Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (apply1 #"sin" _ zero, _) : _ -> Term 0 [] |]
      , [mc| (apply1 #"sin" _ (singleTerm _ #1 [(symbol #"π", #1)]), _) : _ ->
               Term 0 [] |]
      , [mc| (apply1 #"sin" _ (singleTerm $n #2 [(symbol #"π", #1)]), $m) : $xss ->
               Term (a * sinHalfPi n ^ m) xss |]
      , [mc| _ -> term |]
      ]

  -- sin (n*π/2) for an integer n.  Uses n `mod` 4 (non-negative in Haskell)
  -- so that negative n get the correct sign; the caller raises the result to
  -- the factor's exponent rather than multiplying by it.
  sinHalfPi :: Integer -> Integer
  sinHalfPi n
    | even n         = 0
    | n `mod` 4 == 1 = 1
    | otherwise      = -1

  g :: TermExpr -> TermExpr
  g term@(Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (apply1 #"cos" _ zero, _) : $xss -> Term a xss |]
      , [mc| (apply1 #"cos" _ (singleTerm _ #2 [(symbol #"π", #1)]), _) : _ ->
               Term 0 [] |]
      , [mc| (apply1 #"cos" _ (singleTerm $n #1 [(symbol #"π", #1)]), $m) : $xss ->
               Term (a * (-1) ^ (abs n * m)) xss |]
      , [mc| _ -> term |]
      ]

  h :: ScalarData -> ScalarData
  h (Div poly1@(Plus ts1) poly2@(Plus ts2)) =
    match dfs (ts1, ts2) (Multiset TermM, Multiset TermM)
      [ [mc| ((term $a ((apply1 #"cos" $cosWhnf $x, #2) : $mr)) : (term $b ((apply1 #"sin" $sinWhnf #x, #2) : #mr)) : $pr, _) ->
               h (Div (Plus (Term a mr : Term (b - a) ((makeApply sinWhnf [x], 2) : mr) : pr)) poly2) |]
      , [mc| ((term $a ((apply1 #"cos" $cosWhnf $x, #2) : $mr)) : $pr1, (term _ ((apply1 #"sin" $sinWhnf #x, #2) : #mr)) : _) ->
               h (Div (Plus (Term a mr : Term (- a) ((makeApply sinWhnf [x], 2) : mr) : pr1)) poly2) |]
      , [mc| _ -> Div poly1 poly2 |]
      ]
-- | Try to decide whether a scalar with a nonzero integer denominator is
-- negative.
--
-- Returns @Just True@ when it is judged negative, @Just False@ when it is
-- judged non-negative (including the scalar 0), and 'Nothing' when the
-- simple checks below cannot decide.  The sign is decided from the integer
-- coefficients only; symbolic factors are implicitly treated as positive.
-- The only mixed-sign shape handled is @a + b * sqrt n@ with integer @n > 0@.
isNegativeScalar :: ScalarData -> Maybe Bool
isNegativeScalar (Div (Plus terms) (Plus [Term d []]))
  -- An empty numerator is 0.  Without this guard 'all' below would hold
  -- vacuously and 0 would be reported as negative.
  | d /= 0 && null terms = Just False
  | d > 0                = analyzeTerms terms
  | d < 0                = not <$> analyzeTerms terms
 where
  analyzeTerms :: [TermExpr] -> Maybe Bool
  analyzeTerms ts
    | all (\(Term a _) -> a < 0) ts = Just True
    | all (\(Term a _) -> a > 0) ts = Just False
    | otherwise =
        match dfs ts (Multiset TermM)
          [ [mc| term $a [] :
                 term $b ((apply1 #"sqrt" _ (singleTerm $n #1 []), #1) : []) :
                 [] ->
                   if n > 0
                     then let lhs = a * a; rhs = b * b * n
                          in if lhs > rhs then Just (a < 0)
                             else if lhs < rhs then Just (b < 0)
                             else Just False
                     else Nothing |]
          , [mc| _ -> Nothing |]
          ]
isNegativeScalar _ = Nothing
-- | Find two square-root factors @sqrt x@ and @sqrt y@ (each with exponent 1)
-- in a monomial whose product @rewriteSqrt (x * y)@ is a single-term scalar.
--
-- On success it returns a 4-tuple:
--
-- * the function value of the first @sqrt@;
-- * the simplified product;
-- * the remaining factors;
-- * a sign correction: @-1@ when both radicands are judged negative by
--   'isNegativeScalar' (@sqrt x * sqrt y = -sqrt (x*y)@ for negative @x, y@),
--   otherwise @1@.
--
-- Only the first candidate found is returned.
findSqrtPairToMerge :: Monomial -> Maybe (WHNFData, ScalarData, Monomial, Integer)
findSqrtPairToMerge xs =
  case candidates of
    (r : _) -> Just r
    []      -> Nothing
 where
  candidates :: [(WHNFData, ScalarData, Monomial, Integer)]
  candidates =
    [ (whnf, simplified, xss, sign)
    | (whnf, x, y, xss) <-
        matchAll dfs xs (Multiset (SymbolM, Eql))
          [ [mc| (apply1 #"sqrt" $whnf $x, #1) :
                 (apply1 #"sqrt" _ $y, #1) : $xss ->
                   (whnf, x, y, xss) |] ]
    , let simplified = rewriteSqrt (mathMult x y)
    , isSingleTermScalar simplified
    , let sign = case (isNegativeScalar x, isNegativeScalar y) of
                   (Just True, Just True) -> -1
                   _                      -> 1
    ]
-- | Whether a scalar has exactly one term in its numerator and exactly one
-- term in its denominator.
isSingleTermScalar :: ScalarData -> Bool
isSingleTermScalar (Div (Plus [_]) (Plus [_])) = True
isSingleTermScalar _                           = False
-- | Simplify square roots in every term.  Rules are tried in order:
--
-- * @(sqrt x)^k@ with @k > 1@ becomes
--   @(sqrt x)^(k `mod` 2) * x^(k `div` 2)@, then is rewritten again.
-- * @sqrt (n*x) * sqrt (m*y)@, for single-term radicands with denominator 1:
--   the common factor @d@ ('termsGcd') is pulled out, and the rest is merged
--   into one @sqrt@.  The @sqrt@ is dropped when the rest is the integer 1.
-- * Otherwise, two @sqrt@ factors whose product simplifies to a single term
--   are merged via 'findSqrtPairToMerge' (with its sign correction), then
--   the term is rewritten again.
--
-- A term may expand into a whole scalar, hence 'mapTerms''.
rewriteSqrt :: ScalarData -> ScalarData
rewriteSqrt = mapTerms' f
 where
  f :: TermExpr -> ScalarData
  f (Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (apply1 #"sqrt" $sqrtWhnf $x, ?(> 1) & $k) : $xss ->
               rewriteSqrt
                 (mathMult (SingleTerm a ((makeApply sqrtWhnf [x], k `mod` 2) : xss))
                           (mathPower x (div k 2))) |]
      , [mc| (apply1 #"sqrt" $sqrtWhnf (singleTerm $n #1 $x), #1) :
             (apply1 #"sqrt" _ (singleTerm $m #1 $y), #1) : $xss ->
               let d@(Term c z) = termsGcd [Term n x, Term m y]
                   Term n' x' = mathDivideTerm (Term n x) d
                   Term m' y' = mathDivideTerm (Term m y) d
               in case (n' * m', Term n' x', Term m' y') of
                    (1, Term _ [], Term _ []) -> mathMult (SingleTerm c z) (SingleTerm a xss)
                    (_, _, _) -> mathMult (SingleTerm c z) (SingleTerm a ((makeApply sqrtWhnf [SingleTerm (n' * m') (x' ++ y')], 1) : xss)) |]
      , [mc| _ -> case findSqrtPairToMerge xs of
                    Just (whnf, merged, remaining, sign) ->
                      rewriteSqrt (SingleTerm (sign * a) ((makeApply whnf [merged], 1) : remaining))
                    Nothing -> SingleTerm a xs |]
      ]
-- | Reduce high powers of @rt n x@, where @n@ is an integer, in every term.
--
-- @(rt n x)^k@ with @k >= n@ becomes @(rt n x)^(k `mod` n) * x^(k `div` n)@.
-- This is consistent with @rt n x@ denoting the n-th root of @x@.
rewriteRt :: ScalarData -> ScalarData
rewriteRt = mapTerms' f
 where
  f :: TermExpr -> ScalarData
  f (Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (apply2 #"rt" _ (singleTerm $n #1 []) $x & $rtnx, ?(>= n) & $k) : $xss ->
               mathMult (SingleTerm a ((rtnx, k `mod` n) : xss))
                        (mathPower x (div k n)) |]
      , [mc| _ -> SingleTerm a xs |]
      ]
-- | Simplify @rtu n@, where @n@ is an integer, using identities of a
-- primitive n-th root of unity @u@.
--
-- * Term level ('f'): @u^k@ with @k >= n@ becomes @u^(k `mod` n)@.  The
--   factor is dropped when @k@ is a multiple of @n@ (@u^0 = 1@).  This also
--   keeps 'g' from seeing a zero exponent when @n = 1@, where
--   @0 == n - 1@ would wrongly turn the factor into @-1@.
-- * Scalar level ('g'): @u^(n-1)@ becomes @-(1 + u + ... + u^(n-2))@, from
--   @1 + u + ... + u^(n-1) = 0@ (valid for n >= 2).
rewriteRtu :: ScalarData -> ScalarData
rewriteRtu = mapTerms' g . mapTerms f
 where
  f :: TermExpr -> TermExpr
  f term@(Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (apply1 #"rtu" _ (singleTerm $n #1 []) & $rtun, ?(>= n) & $k) : $r ->
               if k `mod` n == 0
                 then Term a r
                 else Term a ((rtun, k `mod` n) : r) |]
      , [mc| _ -> term |]
      ]

  g :: TermExpr -> ScalarData
  g (Term a xs) =
    match dfs xs (Multiset (SymbolM, Eql))
      [ [mc| (apply1 #"rtu" _ (singleTerm $n #1 []) & $rtun, ?(== n - 1)) : $mr ->
               mathMult
                 (foldl mathMinus (SingleTerm (-1) []) (map (\k -> SingleTerm 1 [(rtun, k)]) [1..(n-2)]))
                 (g (Term a mr)) |]
      , [mc| _ -> SingleTerm a xs |]
      ]
-- | Merge like terms in both the numerator and the denominator.
--
-- Two terms @a * F^n * m@ and @b * G^n * m@ are merged into
-- @(a + b) * F^n * m@ when:
--
-- * they have the same remaining factors @m@;
-- * the factors @F@ and @G@ are applications (matched with @func@) of the
--   same function to the same arguments, with the same exponent @n@.
--
-- Merging is repeated until no such pair remains.
rewriteDd :: ScalarData -> ScalarData
rewriteDd (Div (Plus p1) (Plus p2)) =
  Div (Plus (rewriteDdPoly p1)) (Plus (rewriteDdPoly p2))
 where
  rewriteDdPoly :: [TermExpr] -> [TermExpr]
  rewriteDdPoly poly =
    match dfs poly (Multiset TermM)
      [ [mc| term $a (($f & func $g $args, $n) : $mr) :
             term $b ((func #g #args, #n) : #mr) : $pr ->
               rewriteDdPoly (Term (a + b) ((f, n) : mr) : pr) |]
      , [mc| _ -> poly |]
      ]