--
--
-- Mathematics Expressions
--
--

-- Top-level matcher for mathematical expressions.
--   #$v : value pattern; succeeds when the target equals `v`.
--   $   : any other pattern; the target is converted with `fromMathExpr`
--         and then matched with `mathExpr'`.
mathExpr :=
  matcher
    | #$expected as () with
      | $target -> if expected = target then [()] else []
    | $ as (mathExpr') with
      | $target -> [fromMathExpr target]

-- Matcher over the constructor form of a math expression (the `Div`, `Plus`,
-- `Term`, `Symbol`, `Apply`, `Quote` and `Function` data patterns below).
-- Apart from `div`, every named pattern only matches when the denominator is
-- `Plus [Term 1 []]`:
--   div    : numerator and denominator
--   poly   : the terms of the numerator, matched as a multiset
--   plus   : the whole expression, to be matched again with plusExpr
--   term   : coefficient and (factor, exponent) pairs of a one-term numerator
--   mult   : coefficient and the product' of the powered factors (multExpr)
--   symbol, apply, quote, func :
--            a numerator made of exactly one such factor, with coefficient 1
--            and exponent 1
-- div, poly, plus, term, mult, apply and quote convert the extracted
-- sub-expressions with toMathExpr'; symbol and func return the matched fields
-- as they are. Any other pattern receives `toMathExpr' tgt`.
mathExpr' :=
  matcher
    | div $ $ as (mathExpr, mathExpr) with
      | Div $p1 $p2 -> [(toMathExpr' p1, toMathExpr' p2)]
      | _ -> []
    | poly $ as (multiset mathExpr) with
      | Div (Plus $ts) (Plus [Term 1 []]) -> [map toMathExpr' ts]
      | _ -> []
    | plus $ as (plusExpr) with
      | Div (Plus $ts) (Plus [Term 1 []]) ->
        [toMathExpr' (Div (Plus ts) (Plus [Term 1 []]))]
      | _ -> []
    | term $ $ as (integer, assocMultiset mathExpr) with
      | Div (Plus [Term $n $xs]) (Plus [Term 1 []]) ->
        [(n, map 2#(toMathExpr' %1, %2) xs)]
      | _ -> []
    | mult $ $ as (integer, multExpr) with
      | Div (Plus [Term $n $xs]) (Plus [Term 1 []]) ->
        [(n, product' (map 2#(toMathExpr' %1 ^' %2) xs))]
      | _ -> []
    | symbol $ $ as (eq, list indexExpr) with
      | Div (Plus [Term 1 [(Symbol $v $js, 1)]]) (Plus [Term 1 []]) ->
        [(v, js)]
      | _ -> []
    | apply $ $ as (eq, list mathExpr) with
      | Div (Plus [Term 1 [(Apply $v $mexprs, 1)]]) (Plus [Term 1 []]) ->
        [(v, map toMathExpr' mexprs)]
      | _ -> []
    | quote $ as (mathExpr) with
      | Div (Plus [Term 1 [(Quote $mexpr, 1)]]) (Plus [Term 1 []]) ->
        [toMathExpr' mexpr]
      | _ -> []
    | func $ $ $ $ as
        (mathExpr, list mathExpr, list mathExpr, list indexExpr ) with
      | Div
          (Plus [Term 1 [(Function $name $argnames $args $js, 1)]])
          (Plus [Term 1 []]) ->
        [(name, argnames, args, js)]
      | _ -> []
    | $ as (something) with
      | $tgt -> [toMathExpr' tgt]

-- Matcher for index expressions (the index lists of `symbol` and `func` in
-- mathExpr' are matched as `list indexExpr`). An index is algebraic data with
-- one of the constructors `sub`, `sup` or `user`, each carrying one mathExpr.
indexExpr :=
  algebraicDataMatcher
    | sub mathExpr
    | sup mathExpr
    | user mathExpr

-- Alias of mathExpr, for targets that are meant to be polynomials.
polyExpr := mathExpr

-- Alias of mathExpr, for targets that are meant to be single terms.
termExpr := mathExpr

-- Alias of mathExpr, for targets that are meant to be the factors of a term
-- (symbols, applications, ...).
symbolExpr := mathExpr

-- Matcher that views an expression as a sum of terms.
--   []     : succeeds when the target equals 0
--   $ :: $ : for each term picked from the `poly` multiset of the target,
--            yields that term and the sum' of the remaining terms
--   $      : hands the target over unchanged, to be matched as mathExpr
plusExpr :=
  matcher
    | [] as () with
      | $target -> if target = 0 then [()] else []
    | $ :: $ as (mathExpr, plusExpr) with
      | $target ->
        matchAll target as mathExpr with
          | poly ($summand :: $others) -> (summand, sum' others)
    | $ as (mathExpr) with
      | $target -> [target]

-- Matcher that views a single term as a product of its factors. Every
-- decomposing clause matches the target with `term _ $xs`, so the integer
-- coefficient is ignored and only the (factor, exponent) pairs are used.
--   []           : the empty product. The remainder produced by the clauses
--                  below (and by `mult` in mathExpr') is a product' over the
--                  remaining factors, which is 1 once no factor is left, so
--                  this clause matches 1 (it previously matched 0, which none
--                  of those remainders can ever be).
--   $ :: $       : for each factor `x` picked from the assocMultiset, yields
--                  `x` and the product' of the remaining powered factors.
--   ncons $ #$k $: for each (x, n) with n >= k, yields `x` and the product'
--                  of all factors with the exponent of `x` lowered to n - k.
--   ncons $ $ $  : for each (x, n), yields `x`, `n` and the product' of the
--                  other factors.
--   $            : hands the target over unchanged, to be matched as mathExpr.
-- A target that is not a single term yields no match in the decomposing
-- clauses.
multExpr :=
  matcher
    | [] as () with
      | $tgt ->
        match tgt as mathExpr with
          | #1 -> [()]
          | _ -> []
    | $ :: $ as (mathExpr, multExpr) with
      | $tgt ->
        match tgt as mathExpr with
          | term _ $xs ->
            matchAll xs as assocMultiset mathExpr with
              | $x :: $rs -> (x, product' (map (^') rs))
          | _ -> []
    | ncons $ #$k $ as (mathExpr, multExpr) with
      | $tgt ->
        match tgt as mathExpr with
          | term _ $xs ->
            matchAll xs as list (mathExpr, integer) with
              | $hs ++ ($x, ?(>= k) & $n) :: $ts ->
                (x, product' (map (^') (hs ++ (x, n - k) :: ts)))
          | _ -> []
    | ncons $ $ $ as (mathExpr, integer, multExpr) with
      | $tgt ->
        match tgt as mathExpr with
          | term _ $xs ->
            matchAll xs as list (mathExpr, integer) with
              | $hs ++ ($x, $n) :: $ts -> (x, n, product' (map (^') (hs ++ ts)))
          | _ -> []
    | $ as (mathExpr) with
      | $tgt -> [tgt]

-- True when `mexpr` matches the `symbol` pattern of mathExpr, False otherwise.
isSymbol %mexpr :=
  match mexpr as mathExpr with
    | symbol _ _ -> True
    | _ -> False

-- True when `mexpr` matches the `apply` pattern of mathExpr (a function
-- application), False otherwise.
isApply %mexpr :=
  match mexpr as mathExpr with
    | apply _ _ -> True
    | _ -> False

-- True when the argument is a symbol (isSymbol) or a function application
-- (isApply).
isSimpleTerm := 1#(isSymbol %1 || isApply %1)

-- True when `mexpr` is 0 or matches the `term` pattern of mathExpr;
-- False for everything else.
isTerm %mexpr :=
  match mexpr as mathExpr with
    | #0 -> True
    | term _ _ -> True
    | _ -> False

-- True when `mexpr` is 0 or matches the `poly` pattern of mathExpr;
-- False for everything else.
isPolynomial %mexpr :=
  match mexpr as mathExpr with
    | #0 -> True
    | poly _ -> True
    | _ -> False

-- True when `mexpr` is 0, or a quotient whose numerator and denominator are
-- both polynomials consisting of exactly one term; False otherwise.
isMonomial %mexpr :=
  match mexpr as mathExpr with
    | #0 -> True
    | poly [term _ _] / poly [term _ _] -> True
    | _ -> False

--
-- Accessor
--
-- Index list of a symbol expression; `undefined` when `mexpr` does not match
-- the `symbol` pattern.
symbolIndices $mexpr :=
  match mexpr as mathExpr with
    | symbol _ $indices -> indices
    | _ -> undefined

-- Splits a monomial (single term over single term) into a pair:
--   * the quotient of the two integer coefficients, and
--   * the quotient of the two factor products, each built by folding (*')
--     over the powered factors starting from 1.
-- Only the term-over-term shape is handled; there is no fallback clause.
fromMonomial $mexpr :=
  match mexpr as mathExpr with
    | (term $c1 $fs1) / (term $c2 $fs2) ->
      let numer := foldl (*') 1 (map (^') fs1)
          denom := foldl (*') 1 (map (^') fs2)
       in (c1 / c2, numer / denom)

--
-- Map
--
-- Applies `fn` to the numerator and to the denominator of `mexpr` and
-- recombines the two results with (/').
mapPolys $fn $mexpr :=
  match mexpr as mathExpr with
    | $numer / $denom -> fn numer /' fn denom

-- Lists the terms of the numerator of `mexpr`, each divided (/') by the
-- whole denominator.
fromPoly $mexpr :=
  match mexpr as mathExpr with
    | poly $terms / $denom -> map (\t -> t /' denom) terms

-- Applies `fn` to every numerator term divided (/') by the denominator of
-- `mexpr`, and adds the results up with (+') starting from 0.
mapPoly $fn $mexpr :=
  match mexpr as mathExpr with
    | poly $terms / $denom ->
      let images := map (\t -> fn (t /' denom)) terms
       in foldl (+') 0 images

-- Applies `fn` to every term of the numerator and of the denominator of
-- `mexpr`. Each side is summed again with (+') starting from 0, and the two
-- sums are recombined with (/').
mapTerms $fn $mexpr :=
  match mexpr as mathExpr with
    | poly $numerTerms / poly $denomTerms ->
      let numer := foldl (+') 0 (map fn numerTerms)
          denom := foldl (+') 0 (map fn denomTerms)
       in numer /' denom

-- Rewrites the factors of every term of `mexpr` (numerator and denominator,
-- via mapTerms) and rebuilds each term as
--   coefficient *' (product of the rewritten factors, each raised with (^')
--   to its original exponent).
-- Per factor:
--   * symbol      : replaced by `fn factor`.
--   * application : the arguments are rewritten recursively with mapSymbols.
--                   If no argument changed, the factor is kept as it is;
--                   otherwise `fn` is applied to the application rebuilt with
--                   `capply`.
--   * anything else (e.g. quoted or function expressions, which are neither
--                   `symbol` nor `apply`): kept unchanged. Without this
--                   fallback the inner match had no clause for such factors.
mapSymbols $fn $mexpr :=
  mapTerms
    (\term ->
      match term as termExpr with
        | term $a $xs ->
          a *' foldl
                 (*')
                 1
                 (map
                    2#(match %1 as symbolExpr with
                      | symbol _ _ -> fn %1 ^' %2
                      | apply $g $args ->
                        let args' := map 1#(mapSymbols fn %1) args
                         in if args = args'
                              then %1 ^' %2
                              else fn (capply g args') ^' %2
                      | _ -> %1 ^' %2)
                    xs))
    mexpr

-- True when `x` occurs in `mexpr`: some factor of some numerator or
-- denominator term equals `x`, or is a function application one of whose
-- arguments contains `x` (checked recursively).
containSymbol $x $mexpr :=
  match mexpr as mathExpr with
    | poly $numer / poly $denom ->
      any
        id
        (map
           (\t ->
             match t as termExpr with
               | term _ $factors ->
                 any
                   id
                   (map
                      2#(match %1 as symbolExpr with
                        | #x -> True
                        | apply _ $args ->
                          any id (map 1#(containSymbol x %1) args)
                        | _ -> False)
                      factors))
           (numer ++ denom))

-- True when `mexpr` contains an application of `f`: some factor of some
-- numerator or denominator term is an application whose function part equals
-- `f`, or one of whose arguments contains such an application (checked
-- recursively).
containFunction $f $mexpr :=
  match mexpr as mathExpr with
    | poly $numer / poly $denom ->
      any
        id
        (map
           (\t ->
             match t as termExpr with
               | term _ $factors ->
                 any
                   id
                   (map
                      2#(match %1 as symbolExpr with
                        | apply $g $args ->
                          if f = g
                            then True
                            else any id (map 1#(containFunction f %1) args)
                        | _ -> False)
                      factors))
           (numer ++ denom))

-- Like containFunction, but an application of `f` only counts when the
-- exponent of that factor is at least `n`. Applications that do not qualify
-- have their arguments searched recursively with the same `f` and `n`.
containFunctionWithOrder $f $n $mexpr :=
  match mexpr as mathExpr with
    | poly $numer / poly $denom ->
      any
        id
        (map
           (\t ->
             match t as termExpr with
               | term _ $factors ->
                 any
                   id
                   (map
                      2#(match %1 as symbolExpr with
                        | apply $g $args ->
                          if f = g && %2 >= n
                            then True
                            else any
                                   id
                                   (map
                                      1#(containFunctionWithOrder f n %1)
                                      args)
                        | _ -> False)
                      factors))
           (numer ++ denom))

-- True when `mexpr` contains an application whose function part satisfies
-- `isScalar` and is a symbol with a non-empty index list (`symbol _ ![]`).
-- Every factor of every numerator and denominator term is inspected:
--   * application with a scalar function part: True if that part is an
--     indexed symbol, otherwise the arguments are searched recursively;
--   * any other application: the arguments are searched recursively;
--   * anything else: False.
containFunctionWithIndex $mexpr :=
  any
    id
    (match mexpr as mathExpr with
      | poly $ts1 / poly $ts2 ->
        map
          (\term ->
            match term as termExpr with
              | term _ $xs ->
                any
                  id
                  (map
                     2#(match %1 as symbolExpr with
                       | apply (?isScalar & $f) $args ->
                         match f as mathExpr with
                           | symbol _ ![] -> True
                           | _ ->
                             any id (map 1#(containFunctionWithIndex %1) args)
                       | apply _ $args ->
                         any id (map 1#(containFunctionWithIndex %1) args)
                       | _ -> False)
                     xs))
          (ts1 ++ ts2))

-- Collects the symbol factors of `poly`: one result for every way of picking
-- a term from the polynomial and a factor of that term matching `symbol`.
-- Only expressions matching the `poly` pattern produce results.
findSymbolsFromPoly $poly :=
  matchAll poly as mathExpr with
    | poly (term _ ((symbol _ _ & $sym) :: _) :: _) -> sym

--
-- Substitute
--
-- Applies a list of (symbol, replacement) pairs to `mexpr`, one pair after
-- another from the head of `ls`: each substitution works on the result of the
-- previous one. An empty list returns `mexpr` unchanged.
substitute %ls $mexpr :=
  match ls as list (symbolExpr, mathExpr) with
    | [] -> mexpr
    | ($sym, $value) :: $remaining ->
      let replaced := substitute' sym value mexpr
       in substitute remaining replaced

-- Single substitution: runs mapSymbols over `mexpr`, rewriting each visited
-- symbol expression with `rewriteSymbol x a`.
substitute' $x %a $mexpr := mapSymbols 1#(rewriteSymbol x a %1) mexpr

-- Returns `a` when `sexpr` equals `x` (value pattern `#x`); otherwise returns
-- `sexpr` unchanged.
rewriteSymbol $x $a $sexpr :=
  match sexpr as symbolExpr with
    | #x -> a
    | _ -> sexpr

-- Tensor variant of substitute: `xs` and `ys` are flattened with
-- tensorToList, zipped into (symbol, replacement) pairs, and the pairs are
-- applied to `mexpr` with substitute.
V.substitute %xs %ys $mexpr :=
  let pairs := zip (tensorToList xs) (tensorToList ys)
   in substitute pairs mexpr

-- Rebuilds `mexpr` bottom-up with the unprimed arithmetic operators
-- (`*`, `^`, `/`, `product`, `sum`), replacing quoted sub-expressions by
-- their content along the way.
expandAll $mexpr :=
  match mexpr as mathExpr with
    -- symbol: returned unchanged
    | ?isSymbol -> mexpr
    -- function application: rebuild it with expanded arguments
    | apply $fnExpr $args -> capply fnExpr (map expandAll args)
    -- quote: return the quoted content
    | quote $inner -> inner
    -- term (multiplication): coefficient times the expanded powers
    | term $coef $factors ->
      coef * product (map 2#(expandAll %1 ^ expandAll %2) factors)
    -- polynomial: sum of the expanded terms
    | poly $terms -> sum (map expandAll terms)
    -- quotient: expand numerator and denominator separately
    | $numer / $denom -> expandAll numer / expandAll denom

-- Same traversal as expandAll, but rebuilds the expression with the primed
-- arithmetic operators (`*'`, `^'`, `/'`, `product'`, `sum'`).
expandAll' $mexpr :=
  match mexpr as mathExpr with
    -- symbol: returned unchanged
    | ?isSymbol -> mexpr
    -- function application: rebuild it with expanded arguments
    | apply $fnExpr $args -> capply fnExpr (map expandAll' args)
    -- quote: return the quoted content
    | quote $inner -> inner
    -- term (multiplication): coefficient times the expanded powers
    | term $coef $factors ->
      coef *' product' (map 2#(expandAll' %1 ^' expandAll' %2) factors)
    -- polynomial: sum of the expanded terms
    | poly $terms -> sum' (map expandAll' terms)
    -- quotient: expand numerator and denominator separately
    | $numer / $denom -> expandAll' numer /' expandAll' denom

--
-- Coefficient
--
-- Coefficients of `f` with respect to `x`: `coefficient f x d` for every
-- degree `d` in `between 0 m`, where `m` is the largest exponent with which
-- `x` occurs in a numerator term of `f` (0 when `x` does not occur).
coefficients $f $x :=
  let maxDegree := maximum
                     (0 :: (matchAll f as mathExpr with
                       | poly (term _ (ncons #x $k _) :: _) / _ -> k))
   in map 1#(coefficient f x %1) (between 0 maxDegree)

-- Coefficient of `x` to the power `m` in `f`.
-- For m = 0: every numerator term that has no factor `x` is rebuilt from its
-- coefficient and powered factors, the results are summed and the sum is
-- divided by `denominator f`. Any other degree is delegated to coefficient'.
coefficient $f $x $m :=
  if m = 0
    then sum
           (matchAll f as mathExpr with
             | poly (term $c (!(#x :: _) & $fs) :: _) / _ ->
               foldl (*') c (map (^') fs)) / denominator f
    else coefficient' f x m

-- Coefficient of `x` to the power `m` in `f`, for terms that contain `x`.
-- Each numerator term containing `x` with exponent `deg` contributes the
-- product of its coefficient and its other powered factors when deg = m, and
-- 0 otherwise. The sum of the contributions is divided by `denominator f`.
coefficient' $f $x $m :=
  sum
    (matchAll f as mathExpr with
      | poly (term $c (ncons #x $deg $rest) :: _) / _ ->
        if m = deg then foldl (*') c (map (^') rest) else 0) / denominator f

-- Coefficient of the product of `x` and `y` in `f`: for every numerator term
-- from which one `x` and one `y` can be taken, the coefficient multiplied by
-- the remaining powered factors; the sum of these is divided by
-- `denominator f`.
coefficient2 $f $x $y :=
  sum
    (matchAll f as mathExpr with
      | poly (term $c (#x :: #y :: $rest) :: _) / _ ->
        foldl (*') c (map (^') rest)) / denominator f