;;
;; Matrices
;;

;; Variadic matrix product.
;; Folds the binary product `M.b.*` from the left over the arguments, using
;; the first argument as the seed:
;;   (M.* a b c) = (M.b.* (M.b.* a b) c)
;; At least one argument is required, because the seed is taken with `car`.
(define $M.*
  (cambda $mats
    (let {[$seed (car mats)]
          [$others (cdr mats)]}
      (foldl M.b.* seed others))))

;; Binary matrix product.
;; The second index of `m1` and the first index of `m2` share a fresh local
;; symbol and are contracted by `.`; the remaining two positions carry the
;; dummy symbol `#`.
(define $M.b.*
  (lambda [%m1 %m2]
    (with-symbols {k}
      (. m1~#~k
         m2_k_#))))

;; Variadic matrix product built on the primed binary product `M.b.*'`.
;; Folds from the left, using the first argument as the seed:
;;   (M.*' a b c) = (M.b.*' (M.b.*' a b) c)
;; At least one argument is required, because the seed is taken with `car`.
(define $M.*'
  (cambda $mats
    (let {[$seed (car mats)]
          [$others (cdr mats)]}
      (foldl M.b.*' seed others))))

;; Binary matrix product using the primed contraction `.'`.
;; Mirrors `M.b.*`: the second index of `m1` and the first index of `m2`
;; share a fresh local symbol, and both remaining positions carry the dummy
;; symbol `#`.  The right operand is written with its trailing `_#` so that
;; both binary products index their operands in exactly the same way.
(define $M.b.*'
  (lambda [%m1 %m2]
    (with-symbols {j}
      (.' m1~#~j m2_j_#))))

;; Matrix power: hands the variadic product `M.*`, the matrix `m` and the
;; exponent `n` to `repeated-squaring`.
;; NOTE(review): `repeated-squaring` is defined outside this file; presumably
;; it implements exponentiation by squaring -- confirm which exponents it
;; accepts (for example n = 0).
(define $M.power
  (lambda [%m $n]
    (repeated-squaring M.* m n)))
                       
;; Commutator of two matrices: m1 m2 - m2 m1.
;; Both products are contracted over the shared middle symbol; the outer
;; symbols are the same in both terms so that the subtraction lines up
;; entry by entry.
(define $M.comm
  (lambda [%m1 %m2]
    (with-symbols {p q r}
      (- (. m1~p~q m2_q_r)
         (. m2~p~q m1_q_r)))))

;; Inverse of a matrix via cofactors.
;; For the entry at (row, column) of the result, the `matrix` matcher deletes
;; row <column> and column <row> of `m` (note the swapped order, which yields
;; the transposed cofactor matrix).  The four remaining blocks are glued back
;; together with `M.join`, and the determinant of that minor is divided by
;; the determinant of `m`.  The sign alternates with the parity of
;; row + column.
(define $M.inverse
  (lambda [%m]
    (let {[$det-m (M.det m)]}
      (generate-tensor
        2#(match m matrix
            {[<cons ,%2 ,%1 _ $ul $ur $ll $lr>
              (let {[$cofactor (/ (M.det (M.join ul ur ll lr)) det-m)]}
                (if (even? (+ %1 %2))
                  cofactor
                  (* -1 cofactor)))]})
        (tensor-size m)))))

;; Trace of a tensor: gives both indices the same fresh symbol (one upper,
;; one lower) and contracts them with `+`.
(define $trace
  (lambda [%t]
    (with-symbols {k}
      (contract + t~k_k))))

;; Matcher that decomposes a two-index tensor (a matrix).
;;
;; Patterns:
;;   <quad-cons $ $ $ $>
;;     Reads the first two entries of the target's size as m and n and
;;     yields a single decomposition: the entry tgt_1_1, the rest of the
;;     first row tgt_1_[2 n], the rest of the first column tgt_[2 m]_1 and
;;     the remaining block tgt_[2 m]_[2 n].  Yields nothing when the size
;;     list has fewer than two entries.
;;   <cons ,i ,j $ $ $ $ $>
;;     i and j are value patterns naming a row and a column.  Yields the
;;     entry tgt_i_j followed by the four blocks that remain once row i and
;;     column j are removed, in the order upper-left, upper-right,
;;     lower-left, lower-right.
;;   ,val
;;     Matches when (eq? val tgt) holds.
;;   $
;;     Binds the whole target.
(define $matrix
  (matcher
    {[<quad-cons $ $ $ $> [math-expr matrix matrix matrix]
      {[$tgt (match (tensor-size tgt) (list integer)
               {[<cons $m <cons $n _>>
                 {[tgt_1_1 tgt_1_[2 n] tgt_[2 m]_1 tgt_[2 m]_[2 n]]}]
                [_ {}]})]}]
     [<cons ,$i ,$j $ $ $ $ $> [math-expr matrix matrix matrix matrix]
      {[$tgt
        (let* {[$ns (tensor-size tgt)]
               [$m (nth 1 ns)]
               [$n (nth 2 ns)]}
          {[tgt_i_j
            tgt_[1 (- i 1)]_[1 (- j 1)]
            tgt_[1 (- i 1)]_[(+ j 1) n]
            tgt_[(+ i 1) m]_[1 (- j 1)]
            tgt_[(+ i 1) m]_[(+ j 1) n]
            ]})]}]
     [,$val []
      {[$tgt (if (eq? val tgt) {[]} {})]}]
     [$ [something]
      {[$tgt {tgt}]}]
     }))

;; Assembles one matrix from four blocks laid out as
;;
;;     A B
;;     C D
;;
;; The number of rows in the top band is the larger of A's and B's row
;; counts, and the width of the left band is the larger of A's and C's column
;; counts; the bottom and right bands are sized the same way from C/D and
;; B/D.  Each result entry is taken from the first block whose region
;; contains its position, tested in the order A, B, C, D.
;; B's column offset and C's row offset use A's own size, while D's offsets
;; use the band sizes.
(define $M.join
  (lambda [%A %B %C %D]
    (let* {[$size-a (tensor-size A)]
           [$rows-a (nth 1 size-a)] [$cols-a (nth 2 size-a)]
           [$size-b (tensor-size B)]
           [$rows-b (nth 1 size-b)] [$cols-b (nth 2 size-b)]
           [$size-c (tensor-size C)]
           [$rows-c (nth 1 size-c)] [$cols-c (nth 2 size-c)]
           [$size-d (tensor-size D)]
           [$rows-d (nth 1 size-d)] [$cols-d (nth 2 size-d)]
           [$top-rows (max rows-a rows-b)]
           [$left-cols (max cols-a cols-c)]
           [$bottom-rows (max rows-c rows-d)]
           [$right-cols (max cols-b cols-d)]}
      (generate-tensor
        2#(match [%1 %2] [integer integer]
            {[[?(lte? $ rows-a) ?(lte? $ cols-a)] A_%1_%2]
             [[?(lte? $ top-rows) _] B_%1_(- %2 cols-a)]
             [[_ ?(lte? $ left-cols)] C_(- %1 rows-a)_%2]
             [[_ _] D_(- %1 top-rows)_(- %2 left-cols)]})
        {(+ top-rows bottom-rows) (+ left-cols right-cols)}))))

;;
;; Determinant
;;

;; Even and odd permutations of 1..n, each returned as a function.
;; Takes the list form produced by `even-and-odd-permutations'` and wraps
;; every permutation list in a one-argument function that looks up position
;; i of that list.  Returns the tuple [even-functions odd-functions].
(define $even-and-odd-permutations
  (lambda [$n]
    (let {[[$evens $odds] (even-and-odd-permutations' n)]
          [$as-function (lambda [$perm] (lambda [$i] (nth i perm)))]}
      [(map as-function evens)
       (map as-function odds)])))

;; Even and odd permutations of 1..n as lists; returns [evens odds].
;;
;; n = 1 and n = 2 are the base cases.  For larger n the permutations of
;; 1..n-1 are extended in two ways:
;;   * appending n at the end keeps the parity;
;;   * appending n and then swapping the values i and n (for every i in
;;     1..n-1) flips the parity, so the swapped odd permutations are even
;;     and the swapped even permutations are odd.
;; n must be at least 1: any smaller n falls through to the recursive clause
;; and never reaches a base case.
(define $even-and-odd-permutations'
  (lambda [$n]
    (match n integer
      {[,1 [{{1}} {}]]
       [,2 [{{1 2}} {{2 1}}]]
       [_ (let* {[[$prev-even $prev-odd] (even-and-odd-permutations' (- n 1))]
                 [$append-n (lambda [$perms] (map (lambda [$p] {@p n}) perms))]
                 [$swap-n (lambda [$perms]
                            (concat (map (lambda [$i] (map (permutate i n $) perms))
                                         (between 1 (- n 1)))))]
                 [$even-kept (append-n prev-even)]
                 [$odd-kept (append-n prev-odd)]}
            [{@even-kept @(swap-n odd-kept)}
             {@odd-kept @(swap-n even-kept)}])]})))

;; Swaps the values x and y inside the list xs.
;; The first clause handles x occurring before y, the second y occurring
;; before x; in both the elements around and between them keep their places.
;; There is no clause for a list that lacks either value.
(define $permutate
  (lambda [$x $y $xs]
    (match xs (list eq)
      {[<join $front <cons ,x <join $middle <cons ,y $back>>>>
        {@front y @middle x @back}]
       [<join $front <cons ,y <join $middle <cons ,x $back>>>>
        {@front x @middle y @back}]})))

;; Determinant by the permutation expansion.
;;   * A 0x0 tensor has determinant 1.
;;   * For an n x n tensor, every permutation p of 1..n contributes the
;;     product of the entries m_i_(p i); the result is the sum over the even
;;     permutations minus the sum over the odd ones.
;;   * Any other shape yields `undefined`.
(define $M.determinant
  (lambda [%m]
    (match (tensor-size m) (list integer)
      {[<cons ,0 <cons ,0 <nil>>> 1]
       [<cons $n <cons ,n <nil>>>
        (let {[[$evens $odds] (even-and-odd-permutations' n)]
              [$sum-of-products
               (lambda [$perms]
                 (sum (map (lambda [$p]
                             (product (map2 (lambda [$r $c] m_r_c)
                                            (between 1 n)
                                            p)))
                           perms)))]}
          (- (sum-of-products evens)
             (sum-of-products odds)))]
       [_ undefined]})))

;; Short alias for `M.determinant`.
(define $M.det M.determinant)

;;;
;;; Eigenvalues and eigenvectors
;;;

;; Eigenvalues of a 2x2 matrix, returned as a two-element collection.
;; Builds the determinant of `m` minus (scalar-to-tensor x {2 2}) in the free
;; symbol x and passes it to `q-f` together with x; the two values `q-f`
;; returns are the result.  Any other shape yields `undefined`.
;; NOTE(review): `q-f` and `scalar-to-tensor` are defined outside this file;
;; presumably they are the quadratic formula and "x times the identity" --
;; confirm.
(define $M.eigenvalues
  (lambda [%m]
    (match (tensor-size m) (list integer)
      {[<cons ,2 <cons ,2 <nil>>>
        (let* {[$char-poly (M.det (T.- m (scalar-to-tensor x {2 2})))]
               [[$root1 $root2] (q-f char-poly x)]}
          {root1 root2})]
       [_ undefined]})))

;; Eigenvalue/eigenvector pairs of a 2x2 matrix, returned as a collection of
;; two tuples [eigenvalue eigenvector].  Any other shape yields `undefined`.
;;
;; The eigenvalues e1 and e2 are obtained exactly as in `M.eigenvalues`.
;; By the Cayley-Hamilton theorem (m - e1 I)(m - e2 I) = 0, so every column
;; of (m - e2 I) is annihilated by (m - e1 I) and is therefore an eigenvector
;; for e1, and vice versa.  Each eigenvalue is thus paired with the first
;; column of the matrix shifted by the OTHER eigenvalue.
;;
;; Limitation: when that first column is zero (for example for a diagonal
;; matrix) the returned vector is the zero vector.
;; NOTE(review): this assumes (scalar-to-tensor e {2 2}) is e times the
;; identity, as the characteristic polynomial above relies on -- confirm.
(define $M.eigenvectors
  (lambda [%m]
    (match (tensor-size m) (list integer)
      {[<cons ,2 <cons ,2 <nil>>>
        (let {[[$e1 $e2] (q-f (M.det (T.- m (scalar-to-tensor x {2 2}))) x)]}
          {[e1 (clear-index (T.- m (scalar-to-tensor e2 {2 2}))_i_1)]
           [e2 (clear-index (T.- m (scalar-to-tensor e1 {2 2}))_i_1)]})]
       [_ undefined]})))

;;
;; LU decomposition
;;

;; LU decomposition of a 2x2 or 3x3 matrix; returns the tuple [L U].
;; Any other shape yields `undefined`.
;;
;; Two symbolic templates are built for the requested shape:
;;   * L: 0 above the diagonal, 1 on it, symbols b_i_j below it;
;;   * U: 0 below the diagonal, symbols c_i_j on and above it.
;; The entries of their product are paired with the entries of `x` and the
;; unknown each one determines, row by row, and handed to `solve`.  The
;; result is substituted back into both templates.
;; NOTE(review): `solve` and `substitute` are defined outside this file;
;; the meaning of each [lhs rhs unknown] triple is presumed -- confirm.
(define $M.LU
  (lambda [%x]
    (let {[$unit-lower-tri
           (lambda [$shape]
             (generate-tensor
               2#(match (compare %1 %2) ordering
                   {[<less> 0] [<equal> 1] [<greater> b_%1_%2]})
               shape))]
          [$upper-tri
           (lambda [$shape]
             (generate-tensor
               2#(match (compare %1 %2) ordering
                   {[<greater> 0] [_ c_%1_%2]})
               shape))]}
      (match (tensor-size x) (list integer)
        {[<cons ,2 <cons ,2 <nil>>>
          (let* {[$lo (unit-lower-tri {2 2})]
                 [$up (upper-tri {2 2})]
                 [$lu (M.* lo up)]
                 [$sol (solve {[lu_1_1 x_1_1 c_1_1] [lu_1_2 x_1_2 c_1_2]
                               [lu_2_1 x_2_1 b_2_1] [lu_2_2 x_2_2 c_2_2]})]}
            [(substitute sol lo) (substitute sol up)])]
         [<cons ,3 <cons ,3 <nil>>>
          (let* {[$lo (unit-lower-tri {3 3})]
                 [$up (upper-tri {3 3})]
                 [$lu (M.* lo up)]
                 [$sol (solve {[lu_1_1 x_1_1 c_1_1] [lu_1_2 x_1_2 c_1_2] [lu_1_3 x_1_3 c_1_3]
                               [lu_2_1 x_2_1 b_2_1] [lu_2_2 x_2_2 c_2_2] [lu_2_3 x_2_3 c_2_3]
                               [lu_3_1 x_3_1 b_3_1] [lu_3_2 x_3_2 b_3_2] [lu_3_3 x_3_3 c_3_3]})]}
            [(substitute sol lo) (substitute sol up)])]
         [_ undefined]}))))

;;
;; Utility
;;

;; Builds a square matrix from the expression f and the list xs.
;; The matrix has one row and one column per element of xs; the entry at
;; (i, j) is (coefficient2 f xi xj), where xi and xj are the i-th and j-th
;; elements of xs.
;; NOTE(review): `coefficient2` is defined outside this file; presumably it
;; extracts the coefficient of the product xi * xj in f -- confirm.
(define $generate-matrix-from-quadratic-expr
  (lambda [$f $xs]
    (let {[$dim (length xs)]}
      (generate-tensor
        2#(coefficient2 f (nth %1 xs) (nth %2 xs))
        {dim dim}))))