{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Internal.Util(
    
    vector, matrix,
    disp,
    formatSparse,
    approxInt,
    dispDots,
    dispBlanks,
    formatShort,
    dispShort,
    zeros, ones,
    diagl,
    row,
    col,
    (&), (¦), (|||), (——), (===),
    (?), (¿),
    Indexable(..), size,
    Numeric,
    rand, randn,
    cross,
    norm,
    ℕ,ℤ,ℝ,ℂ,iC,
    Normed(..), norm_Frob, norm_nuclear,
    magnit,
    normalize,
    mt,
    (~!~),
    pairwiseD2,
    rowOuters,
    null1,
    null1sym,
    
    
    corr, conv, corrMin,
    
    corr2, conv2, separable,
    block2x2,block3x3,view1,unView1,foldMatrix,
    gaussElim_1, gaussElim_2, gaussElim,
    luST, luSolve', luSolve'', luPacked', luPacked'',
    invershur
) where
import Internal.Vector
import Internal.Matrix hiding (size)
import Internal.Numeric
import Internal.Element
import Internal.Container
import Internal.Vectorized
import Internal.IO
import Internal.Algorithms hiding (Normed,linearSolve',luSolve', luPacked')
import Numeric.Matrix()
import Numeric.Vector()
import Internal.Random
import Internal.Convolution
import Control.Monad(when,forM_)
import Text.Printf
import Data.List.Split(splitOn)
import Data.List(intercalate,sortBy,foldl')
import Control.Arrow((&&&),(***))
import Data.Complex
import Data.Function(on)
import Internal.ST
#if MIN_VERSION_base(4,11,0)
import Prelude hiding ((<>))
#endif
-- | Alias used for natural numbers. It is a plain 'Int', so negative values
-- are not excluded by the type.
type ℕ = Int

-- | Alias used for integers (also a plain 'Int').
type ℤ = Int

-- | Real numbers (double precision).
type ℝ = Double

-- | Complex numbers with double-precision components.
type ℂ = Complex Double
-- | The imaginary unit, @0 + 1i@.
iC :: C
iC = (:+) 0 1
-- | Build a real 'Vector' from a list of elements.
vector :: [R] -> Vector R
vector xs = fromList xs

-- | Build a real 'Matrix' from a list of elements. The list is turned into
-- a vector and then 'reshape'd into a matrix with @c@ columns.
matrix
  :: Int      -- ^ number of columns
  -> [R]      -- ^ elements
  -> Matrix R
matrix c xs = reshape c (fromList xs)
-- | Print a real matrix to standard output. The text is rendered by 'dispf'
-- with precision argument @n@.
disp :: Int -> Matrix Double -> IO ()
disp n m = putStr (dispf n m)
-- | Diagonal matrix whose diagonal entries come from the list.
diagl :: [Double] -> Matrix Double
diagl ds = diag (fromList ds)

-- | Matrix of the given dimensions filled with @0@.
zeros :: Int          -- ^ rows
      -> Int          -- ^ columns
      -> Matrix Double
zeros nr nc = konst 0 (nr, nc)

-- | Matrix of the given dimensions filled with @1@.
ones :: Int           -- ^ rows
     -> Int           -- ^ columns
     -> Matrix Double
ones nr nc = konst 1 (nr, nc)
-- | Concatenate two real vectors.
(&) :: Vector Double -> Vector Double -> Vector Double
infixl 3 &
(&) u v = vjoin [u, v]

-- | Horizontal concatenation: the second matrix is placed to the right of
-- the first.
(|||) :: Element t => Matrix t -> Matrix t -> Matrix t
infixl 3 |||
(|||) left right = fromBlocks [[left, right]]

-- | Real-valued synonym for '|||'.
(¦) :: Matrix Double -> Matrix Double -> Matrix Double
infixl 3 ¦
left ¦ right = left ||| right

-- | Vertical concatenation: the second matrix is placed below the first.
(===) :: Element t => Matrix t -> Matrix t -> Matrix t
infixl 2 ===
(===) upper lower = fromBlocks [[upper], [lower]]

-- | Real-valued synonym for '==='.
(——) :: Matrix Double -> Matrix Double -> Matrix Double
infixl 2 ——
upper —— lower = upper === lower
-- | Single-row matrix built from a list.
row :: [Double] -> Matrix Double
row xs = asRow (fromList xs)

-- | Single-column matrix built from a list.
col :: [Double] -> Matrix Double
col xs = asColumn (fromList xs)
-- | Select the rows with the given indices, using 'extractRows'.
(?) :: Element t => Matrix t -> [Int] -> Matrix t
infixl 9 ?
m ? ixs = extractRows ixs m

-- | Select the columns with the given indices, using 'extractColumns'.
(¿) :: Element t => Matrix t -> [Int] -> Matrix t
infixl 9 ¿
m ¿ ixs = extractColumns ixs m
-- | Cross product of two 3-element vectors.
--
-- Calls 'error' when either argument does not have exactly 3 elements.
cross :: Product t => Vector t -> Vector t -> Vector t
cross x y
    | dim x /= 3 || dim y /= 3 =
        error $ "the cross product requires 3-element vectors (sizes given: "
              ++ show (dim x) ++ " and " ++ show (dim y) ++ ")"
    | otherwise = fromList [ a2*b3 - a3*b2
                           , a3*b1 - a1*b3
                           , a1*b2 - a2*b1 ]
  where
    -- Lazy pattern bindings: these are only forced after the size check.
    [a1, a2, a3] = toList x
    [b1, b2, b3] = toList y
{-# SPECIALIZE cross :: Vector Double -> Vector Double -> Vector Double #-}
{-# SPECIALIZE cross :: Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double) #-}
-- | Euclidean norm of a real vector ('pnorm' with 'PNorm2').
norm :: Vector Double -> Double
norm v = pnorm PNorm2 v
-- | Containers with the standard family of norms. Every norm is returned as
-- a real number. See the individual instances for exactly how each one is
-- computed; for matrices most of them come from 'pnorm'.
class Normed a where
    -- | \"Zero norm\": the number of non-negligible elements.
    norm_0   :: a -> R
    -- | 1-norm.
    norm_1   :: a -> R
    -- | 2-norm.
    norm_2   :: a -> R
    -- | Infinity norm.
    norm_Inf :: a -> R
-- | Real vectors. 'norm_0' counts, through 'step', the entries whose
-- absolute value exceeds @eps * normInf v@. The other norms delegate to
-- 'pnorm'.
instance Normed (Vector R) where
    norm_0 v =
        let tol = scalar (eps * normInf v)
        in  sumElements (step (abs v - tol))
    norm_1 v   = pnorm PNorm1 v
    norm_2 v   = pnorm PNorm2 v
    norm_Inf v = pnorm Infinity v

-- | Complex vectors. These work like real vectors, but the threshold is
-- applied to the moduli (the real parts of @abs v@).
instance Normed (Vector C) where
    norm_0 v =
        let moduli = fst (fromComplex (abs v))
            tol    = scalar (eps * normInf v)
        in  sumElements (step (moduli - tol))
    norm_1 v   = pnorm PNorm1 v
    norm_2 v   = pnorm PNorm2 v
    norm_Inf v = pnorm Infinity v

-- | Real matrices. 'norm_0' is computed on the flattened elements; the
-- other norms are whatever 'pnorm' computes for matrices.
instance Normed (Matrix R) where
    norm_0 m   = norm_0 (flatten m)
    norm_1 m   = pnorm PNorm1 m
    norm_2 m   = pnorm PNorm2 m
    norm_Inf m = pnorm Infinity m

-- | Complex matrices. Same scheme as for real matrices.
instance Normed (Matrix C) where
    norm_0 m   = norm_0 (flatten m)
    norm_1 m   = pnorm PNorm1 m
    norm_2 m   = pnorm PNorm2 m
    norm_Inf m = pnorm Infinity m
-- | Integer vectors ('I').
--
-- 'norm_1' and 'norm_2' accumulate in 'R' rather than in the integer
-- element type. Summing absolute values (via 'norm1') or squares (via
-- 'dot') in a fixed-width integer type overflows silently once the total
-- leaves the type's range. The result is a wrong norm, or NaN from taking
-- 'sqrt' of a wrapped negative sum.
instance Normed (Vector I)
  where
    norm_0 = fromIntegral . sumElements . step . abs
    -- convert each element before 'abs' so that no integer 'abs' can overflow
    norm_1 = foldl' (\acc x -> acc + abs (fromIntegral x)) 0 . toList
    norm_2 v = sqrt (foldl' (\acc x -> let d = fromIntegral x in acc + d * d) 0 (toList v))
    norm_Inf = fromIntegral . normInf

-- | Integer vectors ('Z'). These use the same overflow-free accumulation
-- as the 'I' instance.
instance Normed (Vector Z)
  where
    norm_0 = fromIntegral . sumElements . step . abs
    norm_1 = foldl' (\acc x -> acc + abs (fromIntegral x)) 0 . toList
    norm_2 v = sqrt (foldl' (\acc x -> let d = fromIntegral x in acc + d * d) 0 (toList v))
    norm_Inf = fromIntegral . normInf
-- | Single-precision real vectors. They are converted with 'double' and
-- measured in double precision.
instance Normed (Vector Float) where
    norm_0 v   = norm_0 (double v)
    norm_1 v   = norm_1 (double v)
    norm_2 v   = norm_2 (double v)
    norm_Inf v = norm_Inf (double v)

-- | Single-precision complex vectors. They are converted with 'double' and
-- measured in double precision.
instance Normed (Vector (Complex Float)) where
    norm_0 v   = norm_0 (double v)
    norm_1 v   = norm_1 (double v)
    norm_2 v   = norm_2 (double v)
    norm_Inf v = norm_Inf (double v)
-- | Frobenius norm: the 2-norm of all matrix elements taken as one flat
-- vector.
norm_Frob :: (Normed (Vector t), Element t) => Matrix t -> R
norm_Frob m = norm_2 (flatten m)

-- | Nuclear norm: the sum of the singular values.
norm_nuclear :: Field t => Matrix t -> R
norm_nuclear m = sumElements (singularValues m)
-- | @magnit e x@ is 'True' when the magnitude of @x@ is strictly greater
-- than @e@. The magnitude is the 1-norm of the one-element vector @[x]@.
magnit :: (Element t, Normed (Vector t)) => R -> t -> Bool
magnit e x = e < norm_1 (fromList [x])

-- | Scale a vector to unit 2-norm. The zero vector gets no special
-- handling: it is divided by its own (zero) norm.
normalize :: (Normed (Vector t), Num (Vector t), Field t) => Vector t -> Vector t
normalize v = v / real (scalar len)
  where
    len = norm_2 v
-- | Transpose of the inverse of a real matrix.
mt :: Matrix Double -> Matrix Double
mt m = trans (inv m)
-- | Dimensions of a container. This is 'size'' re-exported under a
-- shorter name.
size :: Container c t => c t -> IndexOf c
size x = size' x
-- | Positional access with an 'Int' index. The functional dependencies tie
-- each container type to exactly one element type, and vice versa.
class Indexable c t | c -> t, t -> c where
    infixl 9 !
    -- | The item at the given index.
    (!) :: c -> Int -> t
-- Every vector instance delegates to '@>'.
instance Indexable (Vector Double) Double where
    v ! i = v @> i

instance Indexable (Vector Float) Float where
    v ! i = v @> i

instance Indexable (Vector I) I where
    v ! i = v @> i

instance Indexable (Vector Z) Z where
    v ! i = v @> i

instance Indexable (Vector (Complex Double)) (Complex Double) where
    v ! i = v @> i

instance Indexable (Vector (Complex Float)) (Complex Float) where
    v ! i = v @> i

-- | Indexing a matrix with @j@ returns the @j@-th run of @cols m@
-- consecutive elements of 'flatten' @m@. For a row-major flatten this is
-- row @j@.
instance Element t => Indexable (Matrix t) (Vector t) where
    m ! j = subVector (j * width) width (flatten m)
      where
        width = cols m
-- | Squared Euclidean distances between the rows of @x@ and the rows of
-- @y@. Entry @(i,j)@ is @|x_i|^2 + |y_j|^2 - 2 x_i . y_j@.
--
-- Calls 'error' when the two matrices have different numbers of columns.
pairwiseD2 :: Matrix Double -> Matrix Double -> Matrix Double
pairwiseD2 x y
    | cols x /= cols y = error $ "pairwiseD2 with different number of columns: "
                               ++ show (size x) ++ ", " ++ show (size y)
    | otherwise = outer sqX onesY + outer onesX sqY - (2 * x) <> trans y
  where
    onesOf k = konst 1 k
    onesX = onesOf (rows x)
    onesY = onesOf (rows y)
    onesC = onesOf (cols x)
    -- squared norm of every row: square elementwise, then sum across each row
    sqX = (x * x) <> onesC
    sqY = (y * y) <> onesC
-- | Row-wise outer products. Each row of @a@ is expanded with 'kronecker' so
-- that every element repeats @cols b@ times. Each row of @b@ is tiled
-- @cols a@ times. The elementwise product of the two then holds, in row
-- @i@, the products @a_ik * b_il@. Nothing here checks that @a@ and @b@
-- have the same number of rows.
rowOuters :: Matrix Double -> Matrix Double -> Matrix Double
rowOuters a b = spreadA * tileB
  where
    spreadA = kronecker a (ones 1 (cols b))
    tileB   = kronecker (ones 1 (cols a)) b
-- | Last column of the right singular vectors returned by 'rightSV'.
-- NOTE(review): if 'rightSV' orders singular values decreasingly, this is
-- the direction of the smallest one (an approximate null vector). Confirm.
null1 :: Matrix R -> Vector R
null1 m = last (toColumns v)
  where
    (_, v) = rightSV m

-- | Last eigenvector column returned by 'eigSH'.
-- NOTE(review): presumably this belongs to the smallest eigenvalue. Confirm
-- the ordering that 'eigSH' uses.
null1sym :: Herm R -> Vector R
null1sym h = last (toColumns vs)
  where
    (_, vs) = eigSH h
-- | @cond ~!~ msg@ raises @error msg@ when @cond@ is 'True'. Otherwise it
-- is a no-op action.
(~!~) :: Applicative f => Bool -> String -> f ()
infixl 0 ~!~
c ~!~ msg = when c (error msg)
-- | Render a real matrix as text, with @sep@ between columns.
--
-- If 'approxInt' accepts the matrix, entries are printed without decimals
-- and exact zeros print as @zeroI@. Otherwise each entry is formatted by
-- these rules:
--
-- * entries below @2*peps@ in absolute value print as @zeroI ++ zeroF@;
-- * entries that are integral up to relative tolerance @2*peps@ print as
--   the integer, then @"."@, then @n@ spaces;
-- * everything else prints with @n@ decimals.
formatSparse :: String -> String -> String -> Int -> Matrix Double -> String
formatSparse zeroI _zeroF sep _ (approxInt -> Just mi) = format sep showInt mi
  where
    showInt x
        | x == 0    = zeroI
        | otherwise = printf "%.0f" x
formatSparse zeroI zeroF sep n m = format sep showReal m
  where
    showReal x
        | ax < 2*peps = zeroI ++ zeroF
        | abs (fromIntegral (round x :: Int) - x) / ax < 2*peps =
            printf ("%.0f." ++ replicate n ' ') x
        | otherwise = printf ("%." ++ show n ++ "f") x
      where
        ax = abs (x :: Double)
-- | The matrix with every entry rounded by 'roundVector', provided the
-- rounding changes it by less than @2*peps@ relative to its infinity norm;
-- otherwise 'Nothing'. The comparison is strict, so an all-zero matrix
-- gives 'Nothing'.
approxInt m =
    if norm_Inf (v - vi) < 2*peps * norm_Inf v
        then Just (reshape (cols m) vi)
        else Nothing
  where
    v  = flatten m
    vi = roundVector v
-- | Print a matrix with 'formatSparse', using @"."@ for zeros (padded with
-- @n@ spaces when the matrix is not integral) and two spaces between
-- columns.
dispDots n m = putStr (formatSparse "." (replicate n ' ') "  " n m)

-- | Like 'dispDots', except that zeros are left blank.
dispBlanks n m = putStr (formatSparse "" "" "  " n m)
-- | Render a matrix as text, abbreviating it when it is large.
--
-- @sep@ separates columns and @fmt@ formats each element. With more than
-- @maxr@ rows, only the first @maxr-3@ rows and the last 2 rows are shown,
-- separated by a line with @':'@ under each @'.'@ of the first line. With
-- more than @maxc@ columns, only the first @maxc-3@ and the last 2 columns
-- are shown, separated by @"  .. "@.
formatShort sep fmt maxr maxc m = auxm4
  where
    (rm,cm) = size m
    -- row split sizes: shown top rows, hidden middle rows, shown last 2 rows
    (r1,r2,r3)
        | rm <= maxr = (rm,0,0)
        | otherwise  = (maxr-3,rm-maxr+1,2)
    -- column split sizes, using the same scheme
    (c1,c2,c3)
        | cm <= maxc = (cm,0,0)
        | otherwise  = (maxc-3,cm-maxc+1,2)
    -- the middle row and middle column of blocks are discarded
    [ [a,_,b]
     ,[_,_,_]
     ,[c,_,d]] = toBlocks [r1,r2,r3]
                          [c1,c2,c3] m
    auxm = fromBlocks [[a,b],[c,d]]
    -- when columns are elided, format with "|" so each line can be split
    -- back into cells below (assumes formatted cells never contain '|')
    auxm2
        | cm > maxc = format "|" fmt auxm
        | otherwise = format sep fmt auxm
    auxm3
        | cm > maxc = map (f . splitOn "|") (lines auxm2)
        | otherwise = (lines auxm2)
    -- rejoin the cells of one line, putting "  .. " after the first maxc-3
    f items = intercalate sep (take (maxc-3) items) ++ "  .. " ++
              intercalate sep (drop (maxc-3) items)
    -- insert the vertical separator line after the first maxr-3 lines
    auxm4
        | rm > maxr = unlines (take (maxr-3) auxm3 ++ vsep : drop (maxr-3) auxm3)
        | otherwise = unlines auxm3
    -- separator: ':' under every '.' of the first line, blanks elsewhere
    -- (only evaluated when rows are elided, so auxm3 is non-empty then)
    vsep = map g (head auxm3)
    g '.' = ':'
    g _ = ' '
-- | Print the dimensions of a matrix, then its abbreviated rendering from
-- 'formatShort' with at most @maxr@ rows and @maxc@ columns shown. Entries
-- are printed with @dec@ decimals.
dispShort :: Int -> Int -> Int -> Matrix Double -> IO ()
dispShort maxr maxc dec m = printf "%dx%d\n%s" (rows m) (cols m) body
  where
    body = formatShort "  " fmt maxr maxc m
    fmt  = printf ("%." ++ show dec ++ "f")
-- | Split a matrix into four blocks around the split point @(r, c)@:
-- @[[top-left, top-right], [bottom-left, bottom-right]]@. The top blocks
-- take the first @r@ rows and the left blocks take the first @c@ columns.
block2x2 r c m =
    [ [m ?? (Take r, Take c), m ?? (Take r, Drop c)]
    , [m ?? (Drop r, Take c), m ?? (Drop r, Drop c)] ]

-- | Split a matrix into a 3x3 grid of blocks. Rows are split into the
-- first @r@, the next @nr@ and the rest; columns into the first @c@, the
-- next @nc@ and the rest.
block3x3 r nr c nc m = [ [ m ?? (rs, cs) | cs <- colSel ] | rs <- rowSel ]
  where
    rowSel = [ Range 0 1 (r-1), Range r 1 (r+nr-1), Drop (nr+r) ]
    colSel = [ Range 0 1 (c-1), Range c 1 (c+nc-1), Drop (nc+c) ]
-- | Decompose a non-empty matrix into its top-left element, the rest of
-- the first row, the rest of the first column and the remaining lower-right
-- block. Returns 'Nothing' for a matrix with no rows or no columns.
view1 :: Numeric t => Matrix t -> Maybe (View1 t)
view1 m
    | rows m <= 0 || cols m <= 0 = Nothing
    | otherwise = Just (m11 `atIndex` (0, 0), flatten m12, flatten m21, m22)
  where
    [[m11, m12], [m21, m22]] = block2x2 1 1 m
-- | A matrix taken apart into its top-left element, the rest of the first
-- row, the rest of the first column and the lower-right block.
type View1 t = (t, Vector t, Vector t, Matrix t)

-- | Inverse of 'view1': rebuild a matrix from its 'View1' pieces.
unView1 :: Numeric t => View1 t -> Matrix t
unView1 (e, r, c, m) = fromBlocks [ [scalar e, asRow r]
                                  , [asColumn c, m] ]
-- | Repeatedly peel off the first row and first column of a matrix.
--
-- At each level, @g@ preprocesses the matrix, 'view1' splits it, @f@
-- transforms the pieces, and the recursion continues on the lower-right
-- block before everything is reassembled with 'unView1'. When 'view1'
-- gives 'Nothing', the matrix is returned as it was, without applying @g@.
foldMatrix :: Numeric t => (Matrix t -> Matrix t) -> (View1 t -> View1 t) -> (Matrix t -> Matrix t)
foldMatrix g f m =
    case f <$> view1 (g m) of
        Just (e, r, c, rest) -> unView1 (e, r, c, foldMatrix g f rest)
        Nothing              -> m
-- | Partial-pivoting step. Find the row @j@ with the largest absolute value
-- in column @k@, then swap rows @0@ and @j@. Returns @j@ together with the
-- permuted matrix. When no swap is needed, or the matrix has no rows, it
-- returns @(0, m)@.
swapMax k m =
    if rows m > 0 && j > 0
        then (j, m ?? (Pos (idxs perm), All))
        else (0, m)
  where
    j    = maxIndex (abs (tr m ! k))
    perm = [j] ++ [1 .. j-1] ++ [0] ++ [j+1 .. rows m - 1]
-- | Downward Gaussian elimination built on 'foldMatrix'. @g@ is applied at
-- every level (for pivoting, for example). At each level the rest of the
-- pivot row is divided by the pivot, and the outer product of the pivot
-- column with it is subtracted from the lower-right block. A zero pivot
-- raises @error "singular!"@.
down g a = foldMatrix g eliminate a
  where
    eliminate (e, r, c, m)
        | e == 0    = error "singular!"
        | otherwise = (1, scaledRow, 0, m - outer c scaledRow)
      where
        scaledRow = r / scalar e
-- | Solve @a x = b@ with two 'down' sweeps over the augmented matrix.
--
-- The first sweep pivots with 'swapMax'. Both halves are then flipped
-- vertically and horizontally, and the second sweep runs without pivoting.
-- Flipping the result back yields the solution.
gaussElim_2
  :: (Eq t, Fractional t, Num (Vector t), Numeric t)
  => Matrix t -> Matrix t -> Matrix t
gaussElim_2 a b = flipBoth solved
  where
    flipBoth = flipud . fliprl
    splitAtA m = (takeColumns (cols a) m, dropColumns (cols a) m)
    sweep g x y = splitAtA (down g (x ||| y))
    (a1, b1) = sweep (snd . swapMax 0) a b
    (_, solved) = sweep id (flipBoth a1) (flipBoth b1)
-- | Solve @x z = y@ by working on a list of rows. The forward pass
-- ('pivotDown') uses partial pivoting. The backward pass ('pivotUp') runs
-- on the vertically flipped result. The columns of the augmented matrix
-- that held @y@ are then returned.
gaussElim_1
  :: (Fractional t, Num (Vector t), Ord t, Indexable (Vector t) t, Numeric t)
  => Matrix t -> Matrix t -> Matrix t
gaussElim_1 x y = dropColumns (rows x) (flipud (fromRows upSwept))
  where
    augmented = toRows (x ||| y)
    downSwept = fromRows (pivotDown (rows x) 0 augmented)
    upSwept   = pivotUp (rows x - 1) (toRows (flipud downSwept))
-- | Forward elimination with partial pivoting over a list of augmented rows.
--
-- Handles pivot columns @n, n+1, .., t-1@. For column @n@, the rows are
-- sorted by decreasing absolute value of their @n@-th entry, and the first
-- row is divided by its pivot. That pivot row's multiples are then
-- subtracted from the other rows to clear their @n@-th entries. The
-- normalized pivot row is emitted, and the remaining rows go on to the
-- next column. A zero pivot raises @error "gauss: singular!"@.
pivotDown
  :: forall t . (Fractional t, Num (Vector t), Ord t, Indexable (Vector t) t, Numeric t)
  => Int -> Int -> [Vector t] -> [Vector t]
pivotDown t n xs
    | t == n    = []
    | otherwise = y : pivotDown t (n+1) ys
  where
    -- lazy pattern binding: it fails only when y/ys is demanded and
    -- redu returned an empty list
    y:ys = redu (pivot n xs)
    -- pair the column index with the rows, sorted by |row ! k| descending
    pivot k = (const k &&& id)
            . sortBy (flip compare `on` (abs. (!k)))
    redu :: (Int, [Vector t]) -> [Vector t]
    redu (k,x:zs)
        | p == 0 = error "gauss: singular!"  -- largest candidate is zero
        | otherwise = u : map f zs
      where
        p = x!k
        -- pivot row scaled so that its k-th entry becomes 1
        u = scale (recip (x!k)) x
        -- clear the k-th entry of another row
        f z = z - scale (z!k) u
    redu (_,[]) = []
-- | Backward elimination over rows given in reverse order, for columns
-- @n, n-1, .., 0@. The head row is kept as the pivot row for column @n@,
-- and its multiples are subtracted from the remaining rows to clear their
-- @n@-th entries. Pivot rows are expected to have a unit pivot already
-- (as produced by 'pivotDown').
pivotUp
  :: forall t . (Fractional t, Num (Vector t), Ord t, Indexable (Vector t) t, Numeric t)
  => Int -> [Vector t] -> [Vector t]
pivotUp n xs
    | n == -1   = []
    | otherwise = y : pivotUp (n - 1) ys
  where
    y:ys = clearColumn xs
    clearColumn :: [Vector t] -> [Vector t]
    clearColumn []           = []
    clearColumn (p : others) = p : [ z - scale (z ! n) p | z <- others ]
-- | Solve @a x = b@ by running 'gaussST' in place on the augmented matrix
-- @a ||| b@, then keeping the columns that held @b@.
gaussElim a b = dropColumns (rows a) solved
  where
    (solved, _) = mutable gaussST (a ||| b)
-- | In-place Gauss-Jordan elimination with partial pivoting on a mutable
-- augmented matrix with @r@ rows. The column count is not used.
--
-- Forward phase, for each column @i@: swap into row @i@ the row at or below
-- it that has the largest absolute entry in column @i@. Scale row @i@ so
-- the pivot becomes 1, then eliminate column @i@ in the rows below.
-- Backward phase: eliminate each column @i@ in the rows above it. A zero
-- pivot raises @error "singular!"@.
gaussST (r,_) x = do
    let n = r-1
        -- presumably row j <- row j + a * row i, as the eliminations below require
        axpy m a i j = rowOper (AXPY a i j AllCols)     m
        swap m i j   = rowOper (SWAP i j AllCols)       m
        scal m a i   = rowOper (SCAL a (Row i) AllCols) m
    forM_ [0..n] $ \i -> do
        -- c is the pivot offset relative to row i
        c <- maxIndex . abs . flatten <$> extractMatrix x (FromRow i) (Col i)
        swap x i (i+c)
        a <- readMatrix x i i
        when (a == 0) $ error "singular!"
        scal x (recip a) i
        forM_ [i+1..n] $ \j -> do
            b <- readMatrix x j i
            axpy x (-b) i j
    forM_ [n,n-1..1] $ \i -> do
        forM_ [i-1,i-2..0] $ \j -> do
            b <- readMatrix x j i
            axpy x (-b) i j
-- | In-place LU factorization with partial pivoting of a mutable matrix
-- with @r@ rows. The multipliers of L are written below the diagonal and
-- U stays on and above it. Returns the pivot sequence: entry @i@ is the
-- index of the row that was swapped with row @i@.
--
-- @ok@ decides whether a pivot can be used. When it rejects one, that
-- column's elimination is skipped; no error is raised.
luST ok (r,_) x = do
    -- only columns to the right of the pivot column are updated; the pivot
    -- column itself receives the multipliers
    let axpy m a i j = rowOper (AXPY a i j (FromCol (i+1))) m
        swap m i j   = rowOper (SWAP i j AllCols)           m
    p <- newUndefinedVector r
    forM_ [0..r-1] $ \i -> do
        -- k is the pivot offset relative to row i
        k <- maxIndex . abs . flatten <$> extractMatrix x (FromRow i) (Col i)
        writeVector p i (k+i)
        swap x i (i+k)
        a <- readMatrix x i i
        when (ok a) $ do
            forM_ [i+1..r-1] $ \j -> do
                b <- (/a) <$> readMatrix x j i
                axpy x (-b) i j
                -- store the multiplier in the eliminated position
                writeMatrix x j i b
    v <- unsafeFreezeVector p
    return (toList v)
-- | Packed LU factorization computed in place by 'luST'. A pivot counts as
-- usable when 'magnit' @0@ holds for it, i.e. its magnitude is positive.
luPacked' x = LU lu perm
  where
    (lu, perm) = mutable (luST (magnit 0)) x
-- | Scale, in place, every element of the region of @x@ described by a
-- 'Slice' by @a@.
scalS a (Slice x r0 c0 nr nc) = rowOper op x
  where
    op = SCAL a (RowRange r0 (r0 + nr - 1)) (ColRange c0 (c0 + nc - 1))
-- | Read the pivot @x[k][k]@ of a mutable matrix with @r@ rows, and return
-- it together with three slices: the rest of row @k@ (@u@), the rest of
-- column @k@ (@l@) and the trailing submatrix (@s@). The result is the
-- tuple @(d, u, l, s)@. For the last row, @u@ and @l@ have zero thickness.
view x k r = do
    d <- readMatrix x k k
    let rest = r - 1 - k
        o    = if k < r - 1 then 1 else 0
    return ( d
           , Slice x k       (k + 1) o    rest
           , Slice x (k + 1) k       rest o
           , Slice x (k + 1) (k + 1) rest rest
           )
-- | Adapt an in-place computation that also fills an auxiliary vector.
--
-- @withVec r f@ allocates an uninitialized vector @p@ of length @r@ and
-- runs @f s x p@, discarding its result. It then returns @p@ frozen. It is
-- used with 'mutable' so that @f@ can record extra output, such as pivot
-- indices.
withVec r f = \s x -> do
    p <- newUndefinedVector r
    _ <- f s x p
    -- freezing without a copy is safe: p is local and no longer mutated
    unsafeFreezeVector p
-- | In-place LU factorization with partial pivoting, expressed with slice
-- operations ('view', 'scalS', 'gemmm'). Returns the packed LU matrix and
-- the pivot list. Entry @k@ of the pivot list is the row swapped with row
-- @k@.
luPacked'' m = (id *** toList) (mutable (withVec (rows m) lu2) m)
  where
    lu2 (r,_) x p = do
        forM_ [0..r-1] $ \k -> do
            pivot x p k
            (d,u,l,s) <- view x k r
            -- pivots with zero magnitude are skipped, not reported
            when (magnit 0 d) $ do
                -- column below the pivot becomes the L multipliers
                scalS (recip d) l
                -- presumably s <- s - l * u (rank-1 update of the trailing block)
                gemmm 1 s (-1) l u
    pivot x p k = do
        -- j is the pivot offset relative to row k
        j <- maxIndex . abs . flatten <$> extractMatrix x (FromRow k) (Col k)
        writeVector p k (j+k)
        swap k (k+j)
      where
        swap i j = rowOper (SWAP i j AllCols) x
-- | All row indices of a matrix: @[0 .. rows m - 1]@.
rowRange m = [0 .. rows m - 1]

-- | Index selector for the single position @k@.
at k = Pos (idxs [k])
-- | Back substitution with the upper-triangular part (diagonal included) of
-- a packed LU matrix, using pure matrix operations. The rows are solved
-- from the bottom up, and each solved row is stacked on top of those
-- already solved.
backSust' lup rhs = foldl' addRow (rhs ? []) (reverse perRow)
  where
    perRow = [ ( lup ?? (at k, at k)
               , lup ?? (at k, Drop (k + 1))
               , rhs ?? (at k, All) )
             | k <- rowRange lup ]
    addRow solved (dk, uk, bk) = ((bk - uk <> solved) / dk) === solved
-- | Forward substitution with the strictly-lower part of a packed LU
-- matrix, using pure matrix operations. The diagonal is implicitly 1, so
-- no division happens. The rows are solved from the top down and appended
-- below those already solved.
forwSust' lup rhs = foldl' addRow (rhs ? []) perRow
  where
    perRow = [ (lup ?? (at k, Take k), rhs ?? (at k, All)) | k <- rowRange lup ]
    addRow solved (lk, bk) = solved === (bk - lk <> solved)
-- | Solve using a packed LU factorization with the pure substitutions
-- 'forwSust'' and 'backSust''. The rows of @b@ are first permuted according
-- to the pivots (through 'fixPerm'').
luSolve'' (LU lup p) b =
    let permuted = b ?? (Pos (fixPerm' p), All)
    in  backSust' lup (forwSust' lup permuted)
-- | In-place forward substitution with the unit lower-triangular part of a
-- packed LU matrix, applied to a copy of @rhs@. @lup@ is thawed without a
-- copy and is only read.
forwSust lup rhs = fst (mutable solve rhs)
  where
    solve (r, c) x = do
        l <- unsafeThawMatrix lup
        forM_ [0 .. r - 1] $ \k ->
            -- presumably: row k of x -= (row k of L, first k cols) * (first k rows of x)
            gemmm 1 (Slice x k 0 1 c) (-1) (Slice l k 0 1 k) (Slice x 0 0 k c)
-- | In-place back substitution with the upper-triangular part of a packed
-- LU matrix, applied to a copy of @rhs@. The rows are processed from the
-- bottom up. Each row first has the contribution of the rows below it
-- removed, and is then scaled by the reciprocal of the diagonal entry.
backSust lup rhs = fst (mutable solve rhs)
  where
    solve (r, c) m = do
        l <- unsafeThawMatrix lup
        let invDiag k = recip (lup `atIndex` (k, k))
            upper k   = Slice l k (k + 1) 1 (r - 1 - k)
            rowK k    = Slice m k 0 1 c
            below k   = Slice m (k + 1) 0 (r - 1 - k) c
        forM_ [r - 1, r - 2 .. 0] $ \k -> do
            gemmm 1 (rowK k) (-1) (upper k) (below k)
            rowOper (SCAL (invDiag k) (Row k) AllCols) m
-- | Solve using a packed LU factorization with the in-place substitutions
-- 'forwSust' and 'backSust'. The rows of @b@ are first permuted according
-- to the pivots (through 'fixPerm'').
luSolve' (LU lup p) b =
    let permuted = b ?? (Pos (fixPerm' p), All)
    in  backSust lup (forwSust lup permuted)
-- | Result of splitting a matrix: either a single element, or four blocks
-- in the order top-left, top-right, bottom-left, bottom-right.
data MatrixView t b = Elem t | Block b b b b
    deriving Show
-- | View a matrix as four blocks split at @(r, c)@. A 1x1 matrix is
-- instead viewed as its single element.
viewBlock' r c m
    | size m == (1, 1) = Elem (atM' m 0 0)
    | otherwise = Block (sub (0, 0) (r, c))
                        (sub (0, c) (r, nc - c))
                        (sub (r, 0) (nr - r, c))
                        (sub (r, c) (nr - r, nc - c))
  where
    (nr, nc) = size m
    sub pos dims = subMatrix pos dims m
-- | 'viewBlock'' with both split points at half the number of rows
-- (rounded down).
viewBlock m =
    let half = rows m `div` 2
    in  viewBlock' half half m
-- | Matrix inverse by recursive 2x2 block partitioning (Schur complement),
-- with no pivoting.
--
-- For @[[A, B], [C, D]]@ with @S = D - C A^-1 B@ the inverse is
-- @[[A^-1 + A^-1 B S^-1 C A^-1, -A^-1 B S^-1], [-S^-1 C A^-1, S^-1]]@.
-- At every level both @A@ and @S@ must be invertible. A view that is not a
-- 'Block' (a 1x1 matrix) is inverted elementwise with 'recip'.
invershur m = case viewBlock m of
    Block a b c d ->
        let ia     = invershur a        -- A^-1
            cia    = c <> ia            -- C A^-1
            iab    = ia <> b            -- A^-1 B
            negS   = (c <> iab) - d     -- -S
            negSi  = invershur negS     -- -S^-1
            b'     = iab <> negSi
            c'     = negSi <> cia
            a'     = ia - (iab <> c')
            d'     = negate negSi
        in  fromBlocks [[a', b'], [c', d']]
    _ -> recip m
-- | Self-test hook for integer matrices; delegates to 'test'.
instance Testable (Matrix I) where
    checkT _ = test
-- | Self-checks on small fixed examples. Integer matrix products must agree
-- with the corresponding double-precision products, and '??' and 'remap'
-- must select the expected elements. The 'IO' component does nothing.
test :: (Bool, IO ())
test = (and checks, return ())
  where
    mi :: Matrix I
    mi = (3><4) [1..12]
    md :: Matrix Double
    md = fromInt mi
    grid :: Matrix I
    grid = (9><10) [0..89]
    rowIx    = (2><3) [1,2,3,4,3,2]
    colIx    = (3><2) [0,4,4,1,2,3]
    expected = (2><3) [10,24,32,44,31,23]
    checks =
        [ tr mi <> mi == toInt (tr md <> md)
        , mi <> tr mi == toInt (md <> tr md)
        , mi ?? (Take 2, Take 3) == remap (asColumn (range 2)) (asRow (range 3)) mi
        , remap rowIx (tr colIx) grid == expected
        , tr grid ?? (PosCyc (idxs [-5, 13]), Pos (idxs [3, 7, 1]))
            == (2><3) [35,75,15,33,73,13]
        ]